github.com/marian-nmt/FBGEMM.git
author     Daya S Khudia <dskhudia@fb.com>  2018-10-13 01:48:13 +0300
committer  Daya S Khudia <dskhudia@fb.com>  2018-10-31 00:56:00 +0300
commit     e85b5a12254fa47ca6b56236489253a68fd32104 (patch)
tree       d62190c53913c65e136fb26dc89bfab38144e2c3
Initial commit
-rw-r--r--  CMakeLists.txt  171
-rw-r--r--  CODE_OF_CONDUCT.md  5
-rw-r--r--  CONTRIBUTING.md  34
-rw-r--r--  LICENSE  30
-rw-r--r--  README.md  93
-rw-r--r--  bench/AlignedVec.h  124
-rw-r--r--  bench/BenchUtils.cc  44
-rw-r--r--  bench/BenchUtils.h  18
-rw-r--r--  bench/CMakeLists.txt  43
-rw-r--r--  bench/Depthwise3DBenchmark.cc  265
-rw-r--r--  bench/DepthwiseBenchmark.cc  342
-rw-r--r--  bench/FP16Benchmark.cc  215
-rw-r--r--  bench/I8SpmdmBenchmark.cc  212
-rw-r--r--  bench/Im2ColFusedRequantizeAcc16Benchmark.cc  241
-rw-r--r--  bench/Im2ColFusedRequantizeAcc32Benchmark.cc  242
-rw-r--r--  bench/PackedFloatInOutBenchmark.cc  269
-rw-r--r--  bench/PackedRequantizeAcc16Benchmark.cc  439
-rw-r--r--  bench/PackedRequantizeAcc32Benchmark.cc  329
-rw-r--r--  cmake/modules/DownloadASMJIT.cmake  16
-rw-r--r--  cmake/modules/DownloadCPUINFO.cmake  16
-rw-r--r--  cmake/modules/DownloadGTEST.cmake  16
-rw-r--r--  cmake/modules/FindMKL.cmake  270
-rw-r--r--  include/fbgemm/ConvUtils.h  95
-rw-r--r--  include/fbgemm/Fbgemm.h  952
-rw-r--r--  include/fbgemm/FbgemmFP16.h  160
-rw-r--r--  include/fbgemm/FbgemmI8Spmdm.h  101
-rw-r--r--  include/fbgemm/OutputProcessing-inl.h  356
-rw-r--r--  include/fbgemm/PackingTraits-inl.h  150
-rw-r--r--  include/fbgemm/Types.h  115
-rw-r--r--  include/fbgemm/Utils.h  123
-rw-r--r--  src/ExecuteKernel.cc  12
-rw-r--r--  src/ExecuteKernel.h  11
-rw-r--r--  src/ExecuteKernelGeneric.h  64
-rw-r--r--  src/ExecuteKernelU8S8.cc  354
-rw-r--r--  src/ExecuteKernelU8S8.h  73
-rw-r--r--  src/Fbgemm.cc  363
-rw-r--r--  src/FbgemmFP16.cc  293
-rw-r--r--  src/FbgemmFP16UKernels.cc  2203
-rw-r--r--  src/FbgemmFP16UKernels.h  40
-rw-r--r--  src/FbgemmI8Depthwise.cc  1953
-rw-r--r--  src/FbgemmI8Depthwise.h  105
-rw-r--r--  src/FbgemmI8Spmdm.cc  508
-rw-r--r--  src/GenerateKernel.h  154
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc  292
-rw-r--r--  src/GenerateKernelU8S8S32ACC16_avx512.cc  295
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc  310
-rw-r--r--  src/GenerateKernelU8S8S32ACC32_avx512.cc  312
-rw-r--r--  src/PackAMatrix.cc  165
-rw-r--r--  src/PackAWithIm2Col.cc  146
-rw-r--r--  src/PackBMatrix.cc  144
-rw-r--r--  src/PackMatrix.cc  86
-rw-r--r--  src/PackWithQuantRowOffset.cc  230
-rw-r--r--  src/PackWithRowOffset.cc  211
-rw-r--r--  src/RefImplementations.cc  608
-rw-r--r--  src/RefImplementations.h  268
-rw-r--r--  src/Utils.cc  357
-rw-r--r--  src/Utils_avx512.cc  243
-rw-r--r--  src/codegen_fp16fp32.cc  387
-rw-r--r--  test/CMakeLists.txt  47
-rw-r--r--  test/FP16Test.cc  124
-rw-r--r--  test/I8DepthwiseTest.cc  448
-rw-r--r--  test/I8DepthwiseTest.h  38
-rw-r--r--  test/I8SpmdmTest.cc  158
-rw-r--r--  test/PackedRequantizeAcc16Test.cc  535
-rw-r--r--  test/PackedRequantizeTest.cc  625
-rw-r--r--  test/QuantizationHelpers.cc  57
-rw-r--r--  test/QuantizationHelpers.h  18
-rw-r--r--  test/TestUtils.cc  100
-rw-r--r--  test/TestUtils.h  40
69 files changed, 17863 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..cfc47b5
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,171 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
+
+#install libraries into correct locations on all platforms
+include(GNUInstallDirs)
+
+project(fbgemm VERSION 0.1 LANGUAGES CXX C)
+
+set(FBGEMM_LIBRARY_TYPE "default" CACHE STRING
+ "Type of library (shared, static, or default) to build")
+set_property(CACHE FBGEMM_LIBRARY_TYPE PROPERTY STRINGS default static shared)
+option(FBGEMM_BUILD_TESTS "Build fbgemm unit tests" ON)
+option(FBGEMM_BUILD_BENCHMARKS "Build fbgemm benchmarks" ON)
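+#Example configure step from a build directory (paths are placeholders):
+#  cmake -DFBGEMM_LIBRARY_TYPE=shared -DASMJIT_SRC_DIR=<path> -DCPUINFO_SRC_DIR=<path> ..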
+
+if(FBGEMM_BUILD_TESTS)
+ enable_testing()
+endif()
+
+set(FBGEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_SOURCE_DIR}/third-party)
+set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+#All the source files that either use avx2 instructions statically or JIT
+#avx2/avx512 instructions.
+set(FBGEMM_AVX2_SRCS src/ExecuteKernel.cc
+ src/ExecuteKernelU8S8.cc
+ src/Fbgemm.cc
+ src/FbgemmFP16.cc
+ src/FbgemmFP16UKernels.cc
+ src/FbgemmI8Depthwise.cc
+ src/FbgemmI8Spmdm.cc
+ src/GenerateKernelU8S8S32ACC16.cc
+ src/GenerateKernelU8S8S32ACC16_avx512.cc
+ src/GenerateKernelU8S8S32ACC32.cc
+ src/GenerateKernelU8S8S32ACC32_avx512.cc
+ src/PackAMatrix.cc
+ src/PackAWithIm2Col.cc
+ src/PackBMatrix.cc
+ src/PackMatrix.cc
+ src/PackWithQuantRowOffset.cc
+ src/PackWithRowOffset.cc
+ src/RefImplementations.cc
+ src/Utils.cc)
+
+#check if compiler supports avx512
+include(CheckCXXCompilerFlag)
+CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512)
+if(NOT COMPILER_SUPPORTS_AVX512)
+ message(FATAL_ERROR "A compiler with AVX512 support is required.")
+endif()
+
+#All the source files that use avx512 instructions statically
+set(FBGEMM_AVX512_SRCS src/Utils_avx512.cc)
+
+set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
+ include/fbgemm/OutputProcessing-inl.h
+ include/fbgemm/PackingTraits-inl.h
+ include/fbgemm/Utils.h
+ include/fbgemm/ConvUtils.h
+ include/fbgemm/Types.h
+ include/fbgemm/FbgemmI8Spmdm.h)
+
+
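+#fbgemm_avx2 and fbgemm_avx512 are compiled as object libraries so that each
+#group of sources gets its own ISA flags; they are combined into the single
+#fbgemm target below.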
+add_library(fbgemm_avx2 OBJECT ${FBGEMM_AVX2_SRCS})
+add_library(fbgemm_avx512 OBJECT ${FBGEMM_AVX512_SRCS})
+
+set_target_properties(fbgemm_avx2 fbgemm_avx512 PROPERTIES
+ CXX_STANDARD 11
+ CXX_EXTENSIONS NO)
+
+target_compile_options(fbgemm_avx2 PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel")
+target_compile_options(fbgemm_avx512 PRIVATE
+ "-m64" "-mavx2" "-mfma" "-mavx512f" "-masm=intel")
+
+if(NOT TARGET asmjit)
+ #Download asmjit from github if ASMJIT_SRC_DIR is not specified.
+ if(NOT DEFINED ASMJIT_SRC_DIR)
+ message(STATUS "Downloading asmjit to ${FBGEMM_THIRDPARTY_DIR}/asmjit
+ (define ASMJIT_SRC_DIR to avoid it)")
+ configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadASMJIT.cmake"
+ "${FBGEMM_BINARY_DIR}/asmjit-download/CMakeLists.txt")
+ execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/asmjit-download")
+ execute_process(COMMAND "${CMAKE_COMMAND}" --build .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/asmjit-download")
+ set(ASMJIT_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/asmjit" CACHE STRING
+ "asmjit source directory")
+ endif()
+
+ #build asmjit
+ set(ASMJIT_STATIC ON)
+ add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit")
+endif()
+
+if(NOT TARGET cpuinfo)
+ #Download cpuinfo from github if CPUINFO_SRC_DIR is not specified.
+ if(NOT DEFINED CPUINFO_SRC_DIR)
+ message(STATUS "Downloading cpuinfo to ${FBGEMM_THIRDPARTY_DIR}/cpuinfo
+ (define CPUINFO_SRC_DIR to avoid it)")
+ configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadCPUINFO.cmake"
+ "${FBGEMM_BINARY_DIR}/cpuinfo-download/CMakeLists.txt")
+ execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download")
+ execute_process(COMMAND "${CMAKE_COMMAND}" --build .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download")
+ set(CPUINFO_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING
+ "cpuinfo source directory")
+ endif()
+
+ #build cpuinfo
+ set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "Do not build cpuinfo unit tests")
+ set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests")
+ set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks")
+ set(CPUINFO_LIBRARY_TYPE static)
+ add_subdirectory("${CPUINFO_SRC_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo")
+ set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
+endif()
+
+target_include_directories(fbgemm_avx2 BEFORE
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
+ PRIVATE "${ASMJIT_SRC_DIR}/src"
+ PRIVATE "${CPUINFO_SRC_DIR}/include")
+
+target_include_directories(fbgemm_avx512 BEFORE
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
+ PRIVATE "${ASMJIT_SRC_DIR}/src"
+ PRIVATE "${CPUINFO_SRC_DIR}/include")
+
+if(FBGEMM_LIBRARY_TYPE STREQUAL "default")
+ add_library(fbgemm $<TARGET_OBJECTS:fbgemm_avx2>
+ $<TARGET_OBJECTS:fbgemm_avx512>)
+elseif(FBGEMM_LIBRARY_TYPE STREQUAL "shared")
+ add_library(fbgemm SHARED $<TARGET_OBJECTS:fbgemm_avx2>
+ $<TARGET_OBJECTS:fbgemm_avx512>)
+elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
+ add_library(fbgemm STATIC $<TARGET_OBJECTS:fbgemm_avx2>
+ $<TARGET_OBJECTS:fbgemm_avx512>)
+else()
+ message(FATAL_ERROR "Unsupported library type ${FBGEMM_LIBRARY_TYPE}")
+endif()
+
+target_include_directories(fbgemm BEFORE
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
+ PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>)
+
+target_link_libraries(fbgemm asmjit cpuinfo)
+add_dependencies(fbgemm asmjit cpuinfo)
+
+install(TARGETS fbgemm EXPORT fbgemmLibraryConfig
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows
+
+install(FILES ${FBGEMM_PUBLIC_HEADERS}
+ DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/fbgemm")
+
+#Make project importable from the build directory
+#export(TARGETS fbgemm asmjit FILE fbgemmLibraryConfig.cmake)
+
+if(FBGEMM_BUILD_TESTS)
+ add_subdirectory(test)
+endif()
+
+if(FBGEMM_BUILD_BENCHMARKS)
+ add_subdirectory(bench)
+endif()
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..0f7ad8b
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,5 @@
+# Code of Conduct
+
+Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
+Please read the [full text](https://code.fb.com/codeofconduct/)
+so that you can understand what actions will and will not be tolerated.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..042f9e4
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,34 @@
+# Contributing to FBGEMM
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Code of Conduct
+The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md).
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `master`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to FBGEMM, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8299579
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,30 @@
+BSD License
+
+For FBGEMM software
+
+Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name Facebook nor the names of its contributors may be used to
+ endorse or promote products derived from this software without specific
+ prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..2b7efaa
--- /dev/null
+++ b/README.md
@@ -0,0 +1,93 @@
+# FBGEMM
+FBGEMM (Facebook GEneral Matrix Multiplication) is a low-precision,
+high-performance matrix-matrix multiplication and convolution library for
+server-side inference.
+
+The library provides efficient low-precision general matrix multiplication for
+small batch sizes and support for accuracy-loss minimizing techniques such as
+row-wise quantization and outlier-aware quantization. FBGEMM also exploits
+fusion opportunities in order to overcome the unique challenges of matrix
+multiplication at lower precision with bandwidth-bound operations.
+
+## Examples
+
+The tests (in the test folder) and benchmarks (in the bench folder) are good
+examples of how to use FBGEMM. For instance, the SpMDMTest in
+test/PackedRequantizeAcc16Test.cc shows how to combine row-offset calculation
+with the packing of A (PackAWithRowOffset), how to pack the B matrix
+(PackBMatrix), and how to construct an output pipeline
+(sparse\_matrix\*dense\_matrix --> requantization --> nop) fused with the inner
+GEMM macro kernel.
+
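+A minimal sketch of that flow, modeled on the benchmark code in the bench
+folder (sizes and quantization parameters here are placeholders; see
+bench/Im2ColFusedRequantizeAcc16Benchmark.cc for a complete, working example):
+
+```
+#include <cstdint>
+#include <vector>
+
+#include "fbgemm/Fbgemm.h"
+
+void gemm_u8s8_acc16_sketch() {
+  using namespace fbgemm2;
+  int M = 64, N = 64, K = 128;
+  std::vector<std::uint8_t> A(M * K, 1); // quantized activations
+  std::vector<std::int8_t> B(K * N, 1);  // quantized weights
+  std::vector<std::int32_t> C(M * N, 0); // int32 outputs
+  std::int32_t A_zero_point = 0;
+
+  // Pack A and compute per-row offsets (needed later for requantization).
+  std::vector<std::int32_t> row_offset_buf(
+      PackAWithRowOffset<std::uint8_t, std::int16_t>::rowOffsetBufferSize());
+  PackAWithRowOffset<std::uint8_t, std::int16_t> packA(
+      matrix_op_t::NoTranspose, M, K, A.data(), K,
+      nullptr, 1, A_zero_point, row_offset_buf.data());
+
+  // Pack B; for inference, weights are typically packed once and reused.
+  PackBMatrix<std::int8_t, std::int16_t> packB(
+      matrix_op_t::NoTranspose, K, N, B.data(), N);
+
+  // Output pipeline: copy the int32 accumulators out (no requantization).
+  DoNothing<std::int32_t, std::int32_t> doNothing;
+  memCopy<> outputProc(doNothing);
+
+  // Single-threaded call: thread id 0 out of 1 thread.
+  fbgemmPacked(packA, packB, C.data(), C.data(), N, outputProc, 0, 1);
+}
+```
+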
+## Build Notes
+FBGEMM uses the standard CMake-based build flow.
+
+### Dependencies
+FBGEMM requires gcc 4.9+ and a CPU with support for the AVX2 instruction set
+or higher. It has been tested on Mac OS X and Linux.
+
++ ###### asmjit
+For its inner kernels, FBGEMM takes a “one size doesn't fit all” approach, so
+the implementation dynamically generates efficient, matrix-shape-specific
+vectorized code using a third-party library called [asmjit][1]. **asmjit is
+required** to build FBGEMM.
+
++ ###### cpuinfo
+FBGEMM detects CPU instruction set support at runtime using cpuinfo library and
+dispatches optimized kernels for the detected instruction set. Therefore,
+**cpuinfo is required** to detect CPU type.
+
++ ###### googletest
+googletest is required to build and run FBGEMM's tests. **googletest is not
+required** if you don't want to run FBGEMM tests. By default, building of tests
+is **on**. Turn it off by setting FBGEMM\_BUILD\_TESTS to off.
+
+You can download [asmjit][1], [cpuinfo][2], and [googletest][3] yourself and
+set ASMJIT\_SRC\_DIR, CPUINFO\_SRC\_DIR, and GOOGLETEST\_SOURCE\_DIR,
+respectively, so that cmake can find these libraries. If any of these variables
+is not set, cmake will download the missing library into a folder called
+third-party in the current directory and build it from the downloaded source.
+
+FBGEMM, in general, does not depend on Intel MKL. However, some benchmarks use
+MKL functions for performance comparisons. If MKL is found or an MKL path is
+provided with INTEL\_MKL\_DIR, the benchmarks are built with MKL and
+performance numbers are reported for the MKL functions as well; if MKL is not
+found, the benchmarks are built without the MKL comparison.
+
+General build instructions are as follows:
+
+```
+mkdir build && cd build
+cmake ..
+make
+```
+
+To run the tests after building FBGEMM (if tests are built), use the following
+command:
+```
+make test
+```
+
+## Installing FBGEMM
+```
+make install
+```
+
+## How FBGEMM works
+For a high-level overview, design philosophy and brief descriptions of various
+parts of FBGEMM please see [our blog][4].
+
+## Full documentation
+We have extensively used comments in our source files. The best and most
+up-to-date documentation is available in the source files.
+
+## Join the FBGEMM community
+See the [`CONTRIBUTING`](CONTRIBUTING.md) file for how to help out.
+
+## License
+FBGEMM is BSD licensed, as found in the [`LICENSE`](LICENSE) file.
+
+
+[1]:https://github.com/asmjit/asmjit
+[2]:https://github.com/pytorch/cpuinfo
+[3]:https://github.com/google/googletest
+[4]:https://code.fb.com/ai-research/
diff --git a/bench/AlignedVec.h b/bench/AlignedVec.h
new file mode 100644
index 0000000..30fd266
--- /dev/null
+++ b/bench/AlignedVec.h
@@ -0,0 +1,124 @@
+#pragma once
+
+#include <cassert>
+#include <cstdlib>
+#include <stdexcept>
+#include <vector>
+
+/**
+ * Allocator for aligned data.
+ *
+ * Modified from the Mallocator from Stephan T. Lavavej.
+ * <http://blogs.msdn.com/b/vcblog/archive/2008/08/28/the-mallocator.aspx>
+ *
+ */
+template <typename T, std::size_t Alignment> class aligned_allocator {
+public:
+ // The following will be the same for virtually all allocators.
+ typedef T *pointer;
+ typedef const T *const_pointer;
+ typedef T &reference;
+ typedef const T &const_reference;
+ typedef T value_type;
+ typedef std::size_t size_type;
+ typedef std::ptrdiff_t difference_type;
+
+ T *address(T &r) const { return &r; }
+
+ const T *address(const T &s) const { return &s; }
+
+ std::size_t max_size() const {
+ // The following has been carefully written to be independent of
+ // the definition of size_t and to avoid signed/unsigned warnings.
+ return (static_cast<std::size_t>(0) - static_cast<std::size_t>(1)) /
+ sizeof(T);
+ }
+
+ // The following must be the same for all allocators.
+ template <typename U> struct rebind {
+ typedef aligned_allocator<U, Alignment> other;
+ };
+
+ bool operator!=(const aligned_allocator &other) const {
+ return !(*this == other);
+ }
+
+ void construct(T *const p, const T &t) const {
+ void *const pv = static_cast<void *>(p);
+
+ new (pv) T(t);
+ }
+
+ void destroy(T *const p) const { p->~T(); }
+
+ // Returns true if and only if storage allocated from *this
+ // can be deallocated from other, and vice versa.
+ // Always returns true for stateless allocators.
+ bool operator==(const aligned_allocator & /*other*/) const { return true; }
+
+ // Default constructor, copy constructor, rebinding constructor, and
+ // destructor. Empty for stateless allocators.
+ aligned_allocator() {}
+
+ aligned_allocator(const aligned_allocator &) {}
+
+ template <typename U>
+ aligned_allocator(const aligned_allocator<U, Alignment> &) {}
+
+ ~aligned_allocator() {}
+
+ // The following will be different for each allocator.
+ T *allocate(const std::size_t n) const {
+ // The return value of allocate(0) is unspecified.
+ // Mallocator returns NULL in order to avoid depending
+ // on malloc(0)'s implementation-defined behavior
+ // (the implementation can define malloc(0) to return NULL,
+ // in which case the bad_alloc check below would fire).
+ // All allocators can return NULL in this case.
+ if (n == 0) {
+ return nullptr;
+ }
+
+ // All allocators should contain an integer overflow check.
+ // The Standardization Committee recommends that std::length_error
+ // be thrown in the case of integer overflow.
+ if (n > max_size()) {
+ throw std::length_error(
+ "aligned_allocator<T>::allocate() - Integer overflow.");
+ }
+
+ // Mallocator wraps malloc().
+ void *pv = nullptr;
+ posix_memalign(&pv, Alignment, n * sizeof(T));
+ // pv = aligned_alloc(Alignment, n * sizeof(T));
+
+ // Allocators should throw std::bad_alloc in the case of memory allocation
+ // failure.
+ if (pv == nullptr) {
+ throw std::bad_alloc();
+ }
+
+ return static_cast<T *>(pv);
+ }
+
+ void deallocate(T *const p, const std::size_t /*n*/) const { free(p); }
+
+ // The following will be the same for all allocators that ignore hints.
+ template <typename U>
+ T *allocate(const std::size_t n, const U * /* const hint */) const {
+ return allocate(n);
+ }
+
+ // Allocators are not required to be assignable, so
+ // all allocators should have a private unimplemented
+ // assignment operator. Note that this will trigger the
+ // off-by-default (enabled under /Wall) warning C4626
+ // "assignment operator could not be generated because a
+ // base class assignment operator is inaccessible" within
+ // the STL headers, but that warning is useless.
+private:
+ aligned_allocator &operator=(const aligned_allocator &) { assert(0); }
+};
+
+template <typename T>
+using aligned_vector = std::vector<T, aligned_allocator<T, 64> >;
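+
+// Usage sketch: aligned_vector behaves like std::vector, but its backing
+// storage is 64-byte (cache-line) aligned, which the benchmarks use for their
+// input/output buffers. For example:
+//   aligned_vector<float> A(1024);
+//   assert(reinterpret_cast<std::uintptr_t>(A.data()) % 64 == 0);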
diff --git a/bench/BenchUtils.cc b/bench/BenchUtils.cc
new file mode 100644
index 0000000..5dade2a
--- /dev/null
+++ b/bench/BenchUtils.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "BenchUtils.h"
+#include <random>
+
+namespace fbgemm2 {
+
+std::default_random_engine eng;
+
+template <typename T>
+void randFill(aligned_vector<T> &vec, const int low, const int high) {
+  // Draw from the shared default-seeded engine so benchmark inputs are
+  // reproducible from run to run.
+  std::uniform_int_distribution<int> dis(low, high);
+ for (auto &v : vec) {
+ v = static_cast<T>(dis(eng));
+ }
+}
+
+template
+void randFill<float>(aligned_vector<float> &vec,
+ const int low, const int high);
+template
+void randFill<uint8_t>(aligned_vector<uint8_t> &vec,
+ const int low, const int high);
+template
+void randFill<int8_t>(aligned_vector<int8_t> &vec,
+ const int low, const int high);
+
+template
+void randFill<int>(aligned_vector<int> &vec,
+ const int low, const int high);
+
+void llc_flush(std::vector<char>& llc) {
+ volatile char* data = llc.data();
+ for (int i = 0; i < llc.size(); i++) {
+ data[i]++;
+ }
+}
+
+} // namespace fbgemm2
diff --git a/bench/BenchUtils.h b/bench/BenchUtils.h
new file mode 100644
index 0000000..5dd452e
--- /dev/null
+++ b/bench/BenchUtils.h
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <vector>
+#include "bench/AlignedVec.h"
+
+namespace fbgemm2 {
+
+template <typename T>
+void randFill(aligned_vector<T> &vec, const int low, const int high);
+
+void llc_flush(std::vector<char>& llc);
+
+} // namespace fbgemm2
diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt
new file mode 100644
index 0000000..338c37c
--- /dev/null
+++ b/bench/CMakeLists.txt
@@ -0,0 +1,43 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+find_package(MKL)
+#benchmarks
+macro(add_benchmark BENCHNAME)
+ add_executable(${BENCHNAME} ${ARGN}
+ BenchUtils.cc ../test/QuantizationHelpers.cc)
+ set_target_properties(${BENCHNAME} PROPERTIES
+ CXX_STANDARD 11
+ CXX_EXTENSIONS NO)
+ target_compile_options(${BENCHNAME} PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel")
+ target_link_libraries(${BENCHNAME} fbgemm)
+ add_dependencies(${BENCHNAME} fbgemm)
+ if(${MKL_FOUND})
+ target_include_directories(${BENCHNAME} PRIVATE "${MKL_INCLUDE_DIR}")
+ target_link_libraries(${BENCHNAME} "${MKL_LIBRARIES}")
+ target_compile_options(${BENCHNAME} PRIVATE
+ "-DUSE_MKL")
+ endif()
+ set_target_properties(${BENCHNAME} PROPERTIES FOLDER test)
+endmacro()
+
+if(FBGEMM_BUILD_BENCHMARKS)
+
+ set(BENCHMARKS "")
+
+ file(GLOB BENCH_LIST "*Benchmark.cc")
+ foreach(BENCH_FILE ${BENCH_LIST})
+ get_filename_component(BENCH_NAME "${BENCH_FILE}" NAME_WE)
+ get_filename_component(BENCH_FILE_ONLY "${BENCH_FILE}" NAME)
+ add_benchmark("${BENCH_NAME}"
+ "${BENCH_FILE_ONLY}")
+ list(APPEND BENCHMARKS "${BENCH_NAME}")
+ endforeach()
+
+ add_custom_target(run_benchmarks
+ COMMAND ${BENCHMARKS})
+
+ add_dependencies(run_benchmarks
+ ${BENCHMARKS})
+
+endif()
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
new file mode 100644
index 0000000..b7c7d44
--- /dev/null
+++ b/bench/Depthwise3DBenchmark.cc
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "test/I8DepthwiseTest.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "AlignedVec.h"
+#include "src/FbgemmI8Depthwise.h"
+#include "fbgemm/Utils.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+int main() {
+ // Depthwise is memory BW bound so we want to flush LLC.
+ bool flush = true;
+ std::vector<char> llc;
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+#define llc_flush() \
+ for (auto i = 0; i < llc.size(); i++) { \
+ llc[i]++; \
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 16;
+
+ for (auto shape : shapes_3d) {
+ int N = shape[0];
+ int K = shape[1];
+ int T = shape[2];
+ int H = shape[3];
+ int W = shape[4];
+ int stride_t = shape[5];
+ int stride_h = stride_t;
+ int stride_w = stride_t;
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ aligned_vector<uint8_t> A(N * T * H * W * K);
+ aligned_vector<int8_t> B(K * K_T * K_H * K_W);
+ aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
+ C(C_ref.size());
+
+ randFill(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ randFill(B, -16, 16);
+ int32_t B_zero_point = 5;
+
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
+ int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+
+ float C_multiplier = 255. / (maximum - minimum);
+
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+ int32_t C_zero_point = 5;
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ B.data(),
+ C_multiplier,
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3x3ConvMatrix Bp(K, B.data());
+
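+  // Bandwidth/FLOP model for the timed loop below: A is read once
+  // (1 byte/element), the 3x3x3 filter of each of the K channels is read
+  // once, and each int32 output element is charged 2 * sizeof(int32_t) bytes
+  // of traffic; ops counts one multiply-add (2 ops) per filter tap per
+  // output element.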
+ double ttot = 0;
+ double bytes =
+ double(NITER) *
+ (K * (N * (2. * sizeof(int32_t) * T_OUT * H_OUT * W_OUT + T * H * W) +
+ K_T * K_H * K_W));
+ double ops =
+ double(NITER) * N * T_OUT * H_OUT * W_OUT * K * K_T * K_H * K_W * 2;
+ chrono::time_point<chrono::system_clock> t_begin, t_end;
+ for (int i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush();
+
+ t_begin = chrono::system_clock::now();
+#pragma omp parallel
+ {
+#if _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ depthwise_3x3x3_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ Bp,
+ C.data(),
+ tid,
+ num_threads);
+ }
+ t_end = chrono::system_clock::now();
+ if (i >= NWARMUP) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int g = 0; g < K; ++g) {
+ int32_t expected =
+ C_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g];
+ int32_t actual =
+ C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g];
+ if (expected != actual) {
+ cerr << "Depthwise 3x3 results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << g << "). expected "
+ << expected << " actual " << actual << endl;
+ return -1;
+ }
+ assert(expected == actual);
+ }
+ } // w
+ } // h
+ } // t
+ } // n
+
+ // Report performance
+ printf("N = %d K = %d T = %d H = %d W = %d stride = %d\n", N, K, T, H, W,
+ stride_h);
+ printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9);
+
+ ttot = 0;
+ for (int i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush();
+
+ t_begin = chrono::system_clock::now();
+#pragma omp parallel
+ {
+#if _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ depthwise_3x3x3_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_uint8.data(),
+ col_offsets.data(),
+ bias.data(),
+ false /* fuse_relu */,
+ tid,
+ num_threads);
+ }
+ t_end = chrono::system_clock::now();
+ if (i >= NWARMUP) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int g = 0; g < K; ++g) {
+ uint8_t expected =
+ C_uint8_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K +
+ g];
+ uint8_t actual =
+ C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g];
+ if (expected != actual) {
+ cerr << "Depthwise 3x3 results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << g << "). expected "
+ << (int)expected << " actual " << (int)actual << endl;
+ return -1;
+ }
+ assert(expected == actual);
+ }
+ } // w
+ } // h
+ } // t
+ } // n
+
+ // Report performance
+ printf("N = %d K = %d T = %d H = %d W = %d stride = %d with requantization "
+ "fused\n",
+ N, K, T, H, W, stride_h);
+ printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9);
+ } // for each shape
+
+ return 0;
+}
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
new file mode 100644
index 0000000..de08ff7
--- /dev/null
+++ b/bench/DepthwiseBenchmark.cc
@@ -0,0 +1,342 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "AlignedVec.h"
+#include "src/FbgemmI8Depthwise.h"
+#include "fbgemm/Utils.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+int main() {
+ // From Xray OCR
+ vector<vector<int>> shapes = {
+ // N, G, H_in, W_in, stride
+ { 1, 272, 47, 125, 1, },
+ { 1, 272, 64, 125, 1, },
+ { 1, 272, 66, 125, 1, },
+ { 1, 272, 67, 100, 1, },
+ { 1, 272, 71, 125, 1, },
+ { 1, 272, 74, 125, 1, },
+ { 1, 272, 75, 75, 1, },
+ { 1, 272, 75, 76, 1, },
+ { 1, 272, 75, 79, 1, },
+ { 1, 272, 75, 85, 1, },
+ { 1, 272, 75, 100, 1, },
+ { 1, 272, 75, 103, 1, },
+ { 1, 272, 75, 111, 1, },
+ { 1, 272, 75, 113, 1, },
+ { 1, 272, 94, 75, 1, },
+ { 1, 272, 109, 75, 1, },
+ { 1, 272, 113, 75, 1, },
+ { 1, 272, 117, 75, 1, },
+ { 1, 544, 24, 63, 1, },
+ { 1, 544, 32, 63, 1, },
+ { 1, 544, 33, 63, 1, },
+ { 1, 544, 34, 50, 1, },
+ { 1, 544, 36, 63, 1, },
+ { 1, 544, 37, 63, 1, },
+ { 1, 544, 38, 38, 1, },
+ { 1, 544, 38, 40, 1, },
+ { 1, 544, 38, 43, 1, },
+ { 1, 544, 38, 50, 1, },
+ { 1, 544, 38, 52, 1, },
+ { 1, 544, 38, 56, 1, },
+ { 1, 544, 38, 57, 1, },
+ { 1, 544, 47, 38, 1, },
+ { 1, 544, 55, 38, 1, },
+ { 1, 544, 57, 38, 1, },
+ { 1, 544, 59, 38, 1, },
+ { 1, 1088, 7, 7, 1, },
+ { 51, 1088, 7, 7, 1, },
+ { 59, 1088, 7, 7, 1, },
+ { 70, 1088, 7, 7, 1, },
+ { 71, 1088, 7, 7, 1, },
+ { 77, 1088, 7, 7, 1, },
+ { 79, 1088, 7, 7, 1, },
+ { 84, 1088, 7, 7, 1, },
+ { 85, 1088, 7, 7, 1, },
+ { 89, 1088, 7, 7, 1, },
+ { 93, 1088, 7, 7, 1, },
+ { 96, 1088, 7, 7, 1, },
+ { 100, 1088, 7, 7, 1, },
+
+ { 1, 248, 93, 250, 2, },
+ { 1, 248, 128, 250, 2, },
+ { 1, 248, 132, 250, 2, },
+ { 1, 248, 131, 250, 2, },
+ { 1, 248, 133, 200, 2, },
+ { 1, 248, 141, 250, 2, },
+ { 1, 248, 148, 250, 2, },
+ { 1, 248, 150, 150, 2, },
+ { 1, 248, 150, 151, 2, },
+ { 1, 248, 150, 158, 2, },
+ { 1, 248, 150, 169, 2, },
+ { 1, 248, 150, 200, 2, },
+ { 1, 248, 150, 205, 2, },
+ { 1, 248, 150, 221, 2, },
+ { 1, 248, 150, 225, 2, },
+ { 1, 248, 188, 150, 2, },
+ { 1, 248, 218, 150, 2, },
+ { 1, 248, 225, 150, 2, },
+ { 1, 248, 234, 150, 2, },
+ { 1, 272, 47, 125, 2, },
+ { 1, 272, 64, 125, 2, },
+ { 1, 272, 66, 125, 2, },
+ { 1, 272, 67, 100, 2, },
+ { 1, 272, 71, 125, 2, },
+ { 1, 272, 74, 125, 2, },
+ { 1, 272, 75, 75, 2, },
+ { 1, 272, 75, 76, 2, },
+ { 1, 272, 75, 79, 2, },
+ { 1, 272, 75, 85, 2, },
+ { 1, 272, 75, 100, 2, },
+ { 1, 272, 75, 103, 2, },
+ { 1, 272, 75, 111, 2, },
+ { 1, 272, 75, 113, 2, },
+ { 1, 272, 94, 75, 2, },
+ { 1, 272, 109, 75, 2, },
+ { 1, 272, 113, 75, 2, },
+ { 1, 272, 117, 75, 2, },
+ { 1, 544, 14, 14, 2, },
+ { 51, 544, 14, 14, 2, },
+ { 59, 544, 14, 14, 2, },
+ { 70, 544, 14, 14, 2, },
+ { 71, 544, 14, 14, 2, },
+ { 77, 544, 14, 14, 2, },
+ { 79, 544, 14, 14, 2, },
+ { 84, 544, 14, 14, 2, },
+ { 85, 544, 14, 14, 2, },
+ { 89, 544, 14, 14, 2, },
+ { 93, 544, 14, 14, 2, },
+ { 96, 544, 14, 14, 2, },
+ { 100, 544, 14, 14, 2, },
+ };
+
+ // Depthwise is memory BW bound so we want to flush LLC.
+ bool flush = true;
+ std::vector<char> llc;
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+#define llc_flush() \
+ for (auto i = 0; i < llc.size(); i++) { \
+ llc[i]++; \
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 16;
+
+ for (auto shape : shapes) {
+ int N = shape[0];
+ int G = shape[1];
+ int H = shape[2];
+ int W = shape[3];
+ int stride_h = shape[4];
+ int stride_w = stride_h;
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ aligned_vector<uint8_t> A(N * H * W * G);
+ aligned_vector<int8_t> B(G * R * S);
+ aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * G), C(C_ref.size());
+
+ randFill(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ randFill(B, -16, 16);
+ int32_t B_zero_point = 5;
+
+ depthwise_3x3_pad_1_ref(
+ N,
+ H,
+ W,
+ G,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
+ int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+
+ float C_multiplier = 255. / (maximum - minimum);
+
+ aligned_vector<int32_t> col_offsets(G);
+ aligned_vector<int32_t> bias(G);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+ int32_t C_zero_point = 5;
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3_pad_1_ref(
+ N,
+ H,
+ W,
+ G,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ B.data(),
+ C_multiplier,
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3ConvMatrix Bp(G, B.data());
+
+ double ttot = 0;
+ double bytes =
+ double(NITER) *
+ (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S));
+ double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2;
+ chrono::time_point<chrono::system_clock> t_begin, t_end;
+ for (int i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush();
+
+ t_begin = chrono::system_clock::now();
+#pragma omp parallel
+ {
+#ifdef _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ depthwise_3x3_pad_1(
+ N,
+ H,
+ W,
+ G,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ Bp,
+ C.data(),
+ tid,
+ num_threads);
+ }
+ t_end = chrono::system_clock::now();
+ if (i >= NWARMUP) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int g = 0; g < G; ++g) {
+ int32_t expected = C_ref[((n * H_OUT + h) * W_OUT + w) * G + g];
+ int32_t actual = C[((n * H_OUT + h) * W_OUT + w) * G + g];
+ if (expected != actual) {
+ cerr << "Depthwise 3x3 results differ at (" << n << ", "
+ << h << ", " << w << ", " << g << "). expected "
+ << expected << " actual " << actual << endl;
+ return -1;
+ }
+ assert(expected == actual);
+ }
+ }
+ }
+ }
+
+ // Report performance
+ printf("N = %d G = %d H = %d W = %d stride = %d\n", N, G, H, W, stride_h);
+ printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9);
+
+ ttot = 0;
+ for (int i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush();
+
+ t_begin = chrono::system_clock::now();
+#pragma omp parallel
+ {
+#ifdef _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ depthwise_3x3_pad_1(
+ N,
+ H,
+ W,
+ G,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_uint8.data(),
+ col_offsets.data(),
+ bias.data(),
+ tid,
+ num_threads);
+ }
+ t_end = chrono::system_clock::now();
+ if (i >= NWARMUP) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int g = 0; g < G; ++g) {
+ uint8_t expected =
+ C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g];
+ uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g];
+ if (expected != actual) {
+ cerr << "Depthwise 3x3 results differ at (" << n << ", "
+ << h << ", " << w << ", " << g << "). expected "
+ << (int)expected << " actual " << (int)actual << endl;
+ return -1;
+ }
+ assert(expected == actual);
+ }
+ }
+ }
+ }
+
+ // Report performance
+ printf(
+ "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n",
+ N, G, H, W, stride_h);
+ printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9);
+ } // for each shape
+
+ return 0;
+}
diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc
new file mode 100644
index 0000000..f5ec10f
--- /dev/null
+++ b/bench/FP16Benchmark.cc
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <chrono>
+#include <cmath>
+#include <random>
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "bench/BenchUtils.h"
+#include "fbgemm/FbgemmFP16.h"
+#include "AlignedVec.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+void performance_test() {
+ // cache flush
+ bool flush = true;
+ std::vector<char> llc;
+ if (flush) {
+ llc.resize(64L * 1024L * 1024L, 1.0);
+ }
+
+ float alpha = 1.f, beta = 1.f;
+ matrix_op_t btran = matrix_op_t::Transpose;
+
+ using btype = float16;
+
+#define dataset 1
+
+#if dataset == 1
+ const int NITER = (flush) ? 10 : 100;
+ std::vector<std::vector<int>> shapes;
+ for (auto m = 1; m < 120; m++) {
+ // shapes.push_back({m, 128, 512});
+ shapes.push_back({m, 512, 512});
+ }
+
+#elif dataset == 2
+ const int NITER = (flush) ? 10 : 100;
+#include "shapes_dataset.h"
+
+#else
+ flush = false;
+ constexpr int NITER = 1;
+ std::vector<std::vector<int>> shapes;
+ std::random_device r;
+ std::default_random_engine generator(r());
+ std::uniform_int_distribution<int> dm(1, 100);
+ std::uniform_int_distribution<int> dnk(1, 1024);
+ for (int i = 0; i < 1000; i++) {
+ int m = dm(generator);
+ int n = dnk(generator);
+ int k = dnk(generator);
+ shapes.push_back({m, n, k});
+ }
+#endif
+
+ std::string type;
+ double gflops, gbs, ttot;
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+ int k = s[2];
+
+ aligned_vector<float> A(m * k, 0.f);
+ aligned_vector<float> B(k * n, 0.f);
+ aligned_vector<float> Cg(m * n, 1.f);
+ aligned_vector<float> Cp(m * n, NAN);
+
+ // initialize with small numbers
+ randFill(A, 0, 4);
+
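+  // B is filled, converted to fp16, and packed once up front; the timed loops
+  // below reuse the packed matrix, matching the intended inference pattern of
+  // pre-packing weights.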
+ randFill(B, 0, 4);
+ PackedGemmMatrixFP16 Bp(btran, k, n, alpha, B.data());
+
+ if (beta != 0.0f) {
+ randFill(Cg, 0, 4);
+ Cp = Cg;
+ }
+
+ double nflops = 2.0 * (double)m * (double)n * (double)k * (double)NITER;
+ double nbytes = (4.0 * (double)m * (double)k + 2.0 * (double)k * (double)n +
+ 4.0 * (double)m * (double)n) *
+ NITER;
+
+ // warm up MKL and fbgemm
+ // check correctness at the same time
+ for (auto w = 0; w < 3; w++) {
+#ifdef USE_MKL
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ btran == matrix_op_t::Transpose ? CblasTrans : CblasNoTrans,
+ m,
+ n,
+ k,
+ alpha,
+ A.data(),
+ k,
+ B.data(),
+ (btran == matrix_op_t::NoTranspose) ? n : k,
+ beta,
+ Cg.data(),
+ n);
+#endif
+ cblas_gemm_compute(
+ matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data());
+
+#ifdef USE_MKL
+ // Compare results
+ for (auto i = 0; i < Cg.size(); i++) {
+ // printf("%f %f\n", Cg[i], Cp[i]);
+ assert(std::abs(Cg[i] - Cp[i]) < 1e-3);
+ }
+#endif
+ }
+
+ chrono::time_point<chrono::system_clock> t_begin, t_end;
+#ifdef USE_MKL
+ // Gold via MKL sgemm
+ type = "MKL_FP32";
+ ttot = 0;
+ for (auto it = -3; it < NITER; it++) {
+ if (flush) {
+ for (auto i = 0; i < llc.size(); i++) {
+ llc[i]++;
+ }
+ }
+ t_begin = chrono::system_clock::now();
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ btran == matrix_op_t::Transpose ? CblasTrans : CblasNoTrans,
+ m,
+ n,
+ k,
+ alpha,
+ A.data(),
+ k,
+ B.data(),
+ (btran == matrix_op_t::NoTranspose) ? n : k,
+ beta,
+ Cg.data(),
+ n);
+ t_end = chrono::system_clock::now();
+ if (it >= 0) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+ gflops = nflops / ttot / 1e9;
+ gbs = nbytes / ttot / 1e9;
+ printf(
+ "\n%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
+ type.c_str(),
+ m,
+ n,
+ k,
+ gflops,
+ gbs);
+ ((volatile char*)(llc.data()));
+#endif
+
+ type = "FBP_" + std::string(typeid(btype).name());
+
+ ttot = 0;
+ for (auto it = -3; it < NITER; it++) {
+ if (flush) {
+ for (auto i = 0; i < llc.size(); i++) {
+ llc[i]++;
+ }
+ }
+
+ t_begin = chrono::system_clock::now();
+ cblas_gemm_compute(
+ matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data());
+ t_end = chrono::system_clock::now();
+
+ if (it >= 0) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ ttot += dt;
+ }
+ }
+ gflops = nflops / ttot / 1e9;
+ gbs = nbytes / ttot / 1e9;
+ printf(
+ "%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
+ type.c_str(),
+ m,
+ n,
+ k,
+ gflops,
+ gbs);
+ ((volatile char*)(llc.data()));
+ }
+}
+
+int main(int /*argc*/, char** /*argv*/) {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+
+ performance_test();
+}
diff --git a/bench/I8SpmdmBenchmark.cc b/bench/I8SpmdmBenchmark.cc
new file mode 100644
index 0000000..f97d152
--- /dev/null
+++ b/bench/I8SpmdmBenchmark.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <chrono>
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <random>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "fbgemm/FbgemmI8Spmdm.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+int main() {
+ const vector<array<int, 3>> shapes = {
+ // M, N, K
+ {1024, 1024, 1024},
+ {511, 512, 512},
+ };
+
+ // SpMDM is often memory BW bound so we want to flush LLC.
+ bool flush = true;
+ std::vector<char> llc;
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 16;
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "M, "
+ << "N, "
+ << "K, "
+ << "Density, "
+ << "Accumulation, "
+ << "Initialize (ms), "
+ << "Transpose uint8 (ms), "
+ << "Transpose 32xN (ms), "
+ << "Compute (ms), "
+ << "Transpose 32xN (ms), "
+ << "Total (ms), "
+ << "GB/s, "
+ << "GOPs" << endl;
+#else
+ cout << "M, "
+ << "N, "
+ << "K, "
+ << "Density, "
+ << "Accumulation, "
+ << "GB/s, "
+ << "GOPs" << endl;
+#endif
+
+ for (const auto& shape : shapes) {
+ for (float density : {0.0001f, 0.001f, 0.01f, 0.1f, 1.0f}) {
+ for (bool accumulation : {false, true}) {
+ int M = shape[0];
+ int N = shape[1];
+ int K = shape[2];
+
+ cout << M << ", " << N << ", " << K << ", ";
+
+ aligned_vector<uint8_t> A(M * K);
+ randFill(A, 0, 255);
+
+ fbgemm2::CompressedSparseColumn B_csc(K, N);
+ vector<int32_t> C(M * N);
+ vector<int32_t> C_ref(C.size());
+
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ C_ref[i * N + j] = i + j;
+ }
+ }
+
+ // deterministic random number
+ std::default_random_engine eng;
+ binomial_distribution<> per_col_nnz_dist(K, density);
+ uniform_int_distribution<> value_dist(
+ numeric_limits<int8_t>::min() / 2,
+ numeric_limits<int8_t>::max() / 2);
+
+ vector<int> row_indices(K);
+
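+      // Build the sparse B column by column: draw each column's nnz count
+      // from a binomial distribution matching the target density, then pick
+      // that many distinct (sorted) row indices for the CSC structure.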
+ int total_nnz = 0;
+ for (int j = 0; j < N; ++j) {
+ B_csc.ColPtr()[j] = total_nnz;
+
+ int nnz_of_j = per_col_nnz_dist(eng);
+ total_nnz += nnz_of_j;
+
+ iota(row_indices.begin(), row_indices.end(), 0);
+ shuffle(row_indices.begin(), row_indices.end(), eng);
+ sort(row_indices.begin(), row_indices.begin() + nnz_of_j);
+
+ for (int k = 0; k < nnz_of_j; ++k) {
+ B_csc.RowIdx().push_back(row_indices[k]);
+ B_csc.Values().push_back(value_dist(eng));
+ }
+ }
+ B_csc.ColPtr()[N] = total_nnz;
+
+ double ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double total_initial_time = 0.0;
+ double total_transpose_uint8_time = 0.0;
+ double total_transpose_32xN_time = 0.0;
+ double total_compute_time = 0.0;
+ double total_transpose_Nx32_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+ double ops = double(NITER) * B_csc.NumOfNonZeros() * M * 2;
+ double bytes = double(NITER) *
+ (M * N * sizeof(int32_t) + M * K +
+ B_csc.NumOfNonZeros() * (sizeof(int16_t) + sizeof(int8_t)) +
+ B_csc.ColPtr().size() * sizeof(int32_t));
+
+ spmdm_ref(M, A.data(), K, B_csc, accumulation, C_ref.data(), N);
+
+ chrono::time_point<chrono::system_clock> t_begin, t_end;
+ for (int iter = 0; iter < NWARMUP + NITER; ++iter) {
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ C[i * N + j] = i + j;
+ }
+ }
+ llc_flush(llc);
+
+ t_begin = chrono::system_clock::now();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ spmdm_initial_time = 0.0;
+ spmdm_transpose_uint8_time = 0.0;
+ spmdm_transpose_32xN_time = 0.0;
+ spmdm_compute_time = 0.0;
+ spmdm_transpose_Nx32_time = 0.0;
+ spmdm_run_time = 0.0;
+#endif
+
+#ifndef FBGEMM_MEASURE_TIME_BREAKDOWN
+#pragma omp parallel
+#endif
+ {
+#if defined (FBGEMM_MEASURE_TIME_BREAKDOWN) || !defined(_OPENMP)
+ int num_threads = 1;
+ int tid = 0;
+#else
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#endif
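+          // Divide rows across threads in chunks rounded up to a multiple of
+          // 32 so that each thread works on whole 32-row blocks of A.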
+ int i_per_thread =
+ ((M + 31) / 32 + num_threads - 1) / num_threads * 32;
+ int i_begin = std::min(tid * i_per_thread, M);
+ int i_end = std::min(i_begin + i_per_thread, M);
+
+ block_type_t block = {i_begin, i_end - i_begin, 0, N};
+ B_csc.SpMDM(
+ block, A.data(), K, accumulation, C.data() + i_begin * N, N);
+ }
+ t_end = chrono::system_clock::now();
+ if (iter >= NWARMUP) {
+ double dt = chrono::duration<double>(t_end - t_begin).count();
+ // double dt = chrono::duration_cast<chrono::nanoseconds>(t_end -
+ // t_begin).count();
+ ttot += dt;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_initial_time += spmdm_initial_time;
+ total_transpose_uint8_time += spmdm_transpose_uint8_time;
+ total_transpose_32xN_time += spmdm_transpose_32xN_time;
+ total_compute_time += spmdm_compute_time;
+ total_transpose_Nx32_time += spmdm_transpose_Nx32_time;
+ total_run_time += spmdm_run_time;
+#endif
+ }
+ }
+
+ compare_buffers(C_ref.data(), C.data(), M, N, N, 5, 0);
+
+ cout << fixed << B_csc.Density() << ", " << accumulation << ", ";
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << total_initial_time / (double)NITER / 1e6 << ", "
+ << total_transpose_uint8_time / (double)NITER / 1e6 << ", "
+ << total_transpose_32xN_time / (double)NITER / 1e6 << ", "
+ << total_compute_time / (double)NITER / 1e6 << ", "
+ << total_transpose_Nx32_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", ";
+#endif
+ // Report performance
+ cout << fixed << bytes / ttot / 1e9 << ", " << ops / ttot / 1e9 << endl;
+
+ } // accumulation
+ } // for each density
+ } // for each shape
+
+ return 0;
+}
diff --git a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
new file mode 100644
index 0000000..62010ec
--- /dev/null
+++ b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+void performance_test() {
+ vector<conv_param_t> shapes = {
+ // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ),
+ };
+
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+ chrono::time_point<chrono::high_resolution_clock> begin, end;
+ for (auto conv_p : shapes) {
+ aligned_vector<float> Afp32(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f);
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+
+ aligned_vector<uint8_t> Aint8_out(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC,
+ 0);
+
+ aligned_vector<float> Bfp32(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f);
+ aligned_vector<int8_t> Bint8(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_fb2(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ cout << conv_p.toString() << endl;
+
+ // A matrix (input activations)
+ randFill(Afp32, 0, 5);
+ int32_t Aint8_zero_point = 4;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Aint8[i] = static_cast<uint8_t>(Afp32[i]);
+ }
+
+ // B matrix (weights)
+ randFill(Bfp32, -4, 4);
+ // int32_t Bint8_zero_point = -3;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bint8[i] = static_cast<int8_t>(Bfp32[i]);
+ }
+
+ // reference implementation
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ // matrix dimensions after im2col
+ int MDim = conv_p.MB * conv_p.OH * conv_p.OW;
+ int NDim = conv_p.OC;
+ int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim,
+ // "B unpacked");
+ // packedB.printPackedMatrix("B Packed");
+
+ double ttot = 0;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int16_t> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
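+  // Two variants are timed: first PackAWithIm2Col, which fuses im2col into
+  // the packing of A, then an explicit im2col_ref pass followed by
+  // PackAWithRowOffset on the materialized matrix.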
+ ttot = 0;
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+ }
+ }
+ cout << fixed << "fused im2col GOPs: "
+ << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+
+ compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5);
+
+ ttot = 0;
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data());
+
+ // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim,
+ // KDim, "A_out after im2col unpacked");
+
+ PackAWithRowOffset<uint8_t, int16_t> packAN(
+ matrix_op_t::NoTranspose,
+ MDim,
+ KDim,
+ Aint8_out.data(),
+ KDim,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ fbgemmPacked(
+ packAN,
+ packedB,
+ Cint32_fb2.data(),
+ Cint32_fb2.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+ }
+ }
+ ((volatile char*)(llc.data()));
+
+ // packedB.printPackedMatrix("bench B Packed");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_fb.data(), MDim, NDim, NDim,
+ // "C fb fp32");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_fb2.data(),
+ // MDim, NDim, NDim, "C fb2 fp32");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
+
+ cout << fixed << "unfused im2col GOPs: "
+ << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+ // cout << "total time: " << ttot << " ns" << endl;
+ compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5);
+ } // shapes
+}
+
+int main() {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
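
The GOPs figures printed by this benchmark come from the GEMM view of the convolution: after im2col the problem is an MDim x NDim x KDim matrix multiply, and each output element costs 2*KDim operations (one multiply and one add per reduction step). Dividing the total operation count of the timed iterations by the accumulated time in nanoseconds gives GOP/s directly, since one operation per nanosecond is 10^9 operations per second. A minimal standalone sketch of that arithmetic, with a hypothetical shape and timing:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical im2col-ed GEMM dimensions (MB*OH*OW, OC, KH*KW*IC).
  const std::int64_t MDim = 2 * 14 * 14;
  const std::int64_t NDim = 32;
  const std::int64_t KDim = 3 * 3 * 32;
  const int NITER = 10;

  // Pretend the timed loop accumulated this many nanoseconds.
  const double ttot_ns = 350000.0;

  // 2*M*N*K ops per GEMM, NITER GEMMs, divided by nanoseconds -> GOP/s.
  const double gops =
      static_cast<double>(NITER) * 2.0 * MDim * NDim * KDim / ttot_ns;
  std::cout << "effective GOPs: " << gops << std::endl;
  return 0;
}
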
diff --git a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
new file mode 100644
index 0000000..9adea49
--- /dev/null
+++ b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+void performance_test() {
+ vector<conv_param_t> shapes = {
+ // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ),
+ };
+
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+ chrono::time_point<chrono::high_resolution_clock> begin, end;
+ for (auto conv_p : shapes) {
+ aligned_vector<float> Afp32(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f);
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+
+ aligned_vector<uint8_t> Aint8_out(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC,
+ 0);
+
+ aligned_vector<float> Bfp32(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f);
+ aligned_vector<int8_t> Bint8(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_ref(
+        conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_fb2(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ cout << conv_p.toString() << endl;
+
+ // A matrix (input activations)
+ randFill(Afp32, 0, 5);
+ int32_t Aint8_zero_point = 4;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Aint8[i] = static_cast<uint8_t>(Afp32[i]);
+ }
+
+ // B matrix (weights)
+ randFill(Bfp32, -4, 4);
+ // int32_t Bint8_zero_point = -3;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bint8[i] = static_cast<int8_t>(Bfp32[i]);
+ }
+
+ // reference implementation
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ // matrix dimensions after im2col
+ int MDim = conv_p.MB * conv_p.OH * conv_p.OW;
+ int NDim = conv_p.OC;
+ int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim,
+ // "B unpacked");
+ // packedB.printPackedMatrix("B Packed");
+
+ double ttot = 0;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int32_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int32_t> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int32_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ ttot = 0;
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+ }
+ }
+ cout << fixed << "fused im2col GOPs: "
+ << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+
+ compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5);
+
+ ttot = 0;
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data());
+
+ // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim,
+ // KDim, "A_out after im2col unpacked");
+
+ PackAWithRowOffset<uint8_t, int32_t> packAN(
+ matrix_op_t::NoTranspose,
+ MDim,
+ KDim,
+ Aint8_out.data(),
+ KDim,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ fbgemmPacked(
+ packAN,
+ packedB,
+ Cint32_fb2.data(),
+ Cint32_fb2.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+ }
+ }
+
+ ((volatile char*)(llc.data()));
+
+ // packedB.printPackedMatrix("bench B Packed");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_fb.data(), MDim, NDim, NDim,
+ // "C fb fp32");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_fb2.data(),
+ // MDim, NDim, NDim, "C fb2 fp32");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
+
+ cout << fixed << "unfused im2col GOPs: "
+ << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+ // cout << "total time: " << ttot << " ns" << endl;
+ compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5);
+ } // shapes
+}
+
+int main() {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
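
This file is the 32-bit-accumulation twin of the previous benchmark; the only visible difference is the accT template argument given to PackAWithIm2Col and PackBMatrix (int32_t here, int16_t above). The narrower accumulator can only absorb a couple of worst-case uint8*int8 products before wrapping, which is why a 16-bit accumulation path has to spill into 32-bit state periodically. A small self-contained illustration of that overflow behaviour (plain scalar C++, not FBGEMM kernel code):

#include <cstdint>
#include <iostream>

int main() {
  std::int16_t acc16 = 0;
  std::int32_t acc32 = 0;
  for (int k = 0; k < 4; ++k) {
    const std::int32_t prod = 255 * 127;  // worst-case uint8 * int8 product
    // Narrowing back to 16 bits exceeds INT16_MAX on the second term and
    // wraps (modular wrap on the cast on common implementations).
    acc16 = static_cast<std::int16_t>(acc16 + prod);
    acc32 += prod;
  }
  std::cout << "int16 accumulator: " << acc16
            << "  int32 accumulator: " << acc32 << std::endl;
  return 0;
}
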
diff --git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc
new file mode 100644
index 0000000..4a2eda4
--- /dev/null
+++ b/bench/PackedFloatInOutBenchmark.cc
@@ -0,0 +1,269 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "test/QuantizationHelpers.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+void performance_test() {
+ vector<vector<int>> shapes = {
+ {1, 128, 512},
+ {1, 1024, 256},
+ {1, 2048, 512},
+ {1, 4096, 1024},
+
+ {6, 256, 1024},
+ {6, 256, 2048},
+ {6, 512, 512},
+ {6, 1024, 256},
+ {6, 2048, 256},
+ {6, 2048, 512},
+ {6, 4096, 256},
+ {6, 4096, 1024},
+ {6, 4096, 2048},
+
+ {10, 2048, 256},
+ {10, 4096, 1024},
+
+ {20, 2048, 256},
+ {20, 4096, 1024},
+
+ {102, 1024, 512},
+ {102, 2323, 256},
+ {102, 512, 256},
+
+ {1, 800, 3200},
+ {1, 800, 8000},
+
+ {16, 256, 1500},
+ {16, 256, 1567},
+ {1, 128, 2876},
+ {16, 128, 1567},
+ {1, 128, 2722},
+ {16, 256, 512},
+ };
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+ cout << setw(7) << "M, " << setw(7) << "N, " << setw(7) << "K, " << setw(18)
+ << "Type, " << setw(5) << "GOPS" << endl;
+
+ chrono::time_point<chrono::high_resolution_clock> start, end;
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ float alpha = 1.f, beta = 0.f;
+ aligned_vector<float> Afp32(m * k, 0.0f);
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ aligned_vector<float> Bfp32(k * n, 0.0f);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+
+ aligned_vector<float> Cfp32_mkl(m * n, 0.0f);
+ aligned_vector<float> Cfp32_fb(m * n, 0.0f);
+
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+
+ // A matrix
+ randFill(Aint8, 0, 255);
+ float Aint8_scale = 0.11;
+ int32_t Aint8_zero_point = 43;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Afp32[i] = Aint8_scale * (Aint8[i] - Aint8_zero_point);
+ }
+
+ randFill(Bint8, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8.data());
+
+ float Bint8_scale = 0.49;
+ int32_t Bint8_zero_point = -30;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bfp32[i] = Bint8_scale * (Bint8[i] - Bint8_zero_point);
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data());
+
+ double ttot = 0;
+ std::string type;
+ double nops = 2.0 * (double)m * (double)n * (double)k * (double)NITER;
+#ifdef USE_MKL
+ type = "MKL_FP32";
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ start = chrono::high_resolution_clock::now();
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ CblasNoTrans,
+ m,
+ n,
+ k,
+ alpha,
+ Afp32.data(),
+ k,
+ Bfp32.data(),
+ n,
+ beta,
+ Cfp32_mkl.data(),
+ n);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - start);
+ ttot += dur.count();
+ }
+ }
+ ((volatile char*)(llc.data()));
+ cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k << ", "
+ << setw(16) << type << ", " << setw(5) << fixed << setw(5)
+ << setprecision(1) << nops / ttot << endl;
+#endif
+
+ int32_t C_multiplier = 16544;
+ int32_t C_right_shift = 35;
+ int32_t C_zero_pt = 5;
+
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B
+ // unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Cfp32_mkl.data(),
+ // m, n, n, "C mkl fp32");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint8_local.data(), m, n, n, "C requantized");
+ // printMatrix(matrix_op_t::NoTranspose, col_offsets.data(), 1, n, n, "col
+ // offsets before");
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize());
+
+ PackAWithQuantRowOffset<uint8_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k,
+ Afp32.data(),
+ k,
+ nullptr, /*buffer for packed matrix*/
+ Aint8_scale,
+ Aint8_zero_point,
+ 1, /*groups*/
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t> packedBN(
+ matrix_op_t::NoTranspose,
+ k,
+ n,
+ Bint8.data(),
+ n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ DoNothing<float, float> doNothingObj{};
+ ReQuantizeForFloat<false> outputProcObj(
+ doNothingObj,
+ Aint8_scale,
+ Bint8_scale,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ ttot = 0;
+ type = "FBGEMM_i8_acc32";
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ start = chrono::high_resolution_clock::now();
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cfp32_fb.data(),
+ (int32_t*)Cfp32_fb.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - start);
+ ttot += dur.count();
+ }
+ }
+ ((volatile char*)(llc.data()));
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B
+ // unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(),
+ // m, n, n, "C requantized after");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint8_fb.data(), m, n, n, "C fb");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // col_offsets.data(), 1, n, n, "col offsets after");
+ // compare_buffers(row_offsets.data(), row_offset_buf.data(),
+ // row_offsets.size(), 5);
+ // printMatrix(matrix_op_t::NoTranspose, Cfp32_fb.data(),
+ // m, n, n, "C fb fp32");
+ cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k << ", "
+ << setw(16) << type << ", " << setw(5) << fixed << setw(5)
+ << setprecision(1) << nops / ttot << endl;
+ cout << endl;
+ // cout << "total time: " << ttot << " ns" << endl;
+
+ float maximum = *max_element(Cfp32_mkl.begin(), Cfp32_mkl.end());
+ float minimum = *min_element(Cfp32_mkl.begin(), Cfp32_mkl.end());
+ float atol = (maximum - minimum) / 255 / 1.9;
+
+#ifdef USE_MKL
+ // correctness check
+ compare_buffers(Cfp32_mkl.data(), Cfp32_fb.data(), m, n, n, 5, atol);
+#endif
+ }
+}
+
+int main(int /* unused */, char** /* unused */) {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
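
The float-in/float-out path above quantizes A on the fly (PackAWithQuantRowOffset) and converts the int32 accumulators back to float (ReQuantizeForFloat) before comparing against MKL's sgemm. The comparison ultimately rests on the standard zero-point algebra: an integer dot product can be turned back into the floating-point one using the two scales, the two zero points, and the per-row/per-column sums that the packing step provides. A self-contained scalar sketch of that identity, with hypothetical values and not FBGEMM's actual offset conventions:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical quantization parameters, loosely mirroring the benchmark.
  const float a_scale = 0.11f, b_scale = 0.49f;
  const std::int32_t a_zp = 43, b_zp = -30;

  // One row of A and one column of B, already quantized.
  std::vector<std::uint8_t> Aq = {12, 200, 77, 43};
  std::vector<std::int8_t> Bq = {-5, 17, 0, 120};
  const int K = static_cast<int>(Aq.size());

  // Integer-only accumulation plus the row/column sums packing would provide.
  std::int32_t acc = 0, row_off = 0, col_off = 0;
  for (int k = 0; k < K; ++k) {
    acc += static_cast<std::int32_t>(Aq[k]) * static_cast<std::int32_t>(Bq[k]);
    row_off += Aq[k];
    col_off += Bq[k];
  }

  // Dequantize: scale * (acc - a_zp*col_off - b_zp*row_off + K*a_zp*b_zp).
  const float from_int8 = a_scale * b_scale *
      (acc - a_zp * col_off - b_zp * row_off + K * a_zp * b_zp);

  // Reference float dot product; the two agree up to float rounding.
  float ref = 0.0f;
  for (int k = 0; k < K; ++k) {
    ref += a_scale * (Aq[k] - a_zp) * (b_scale * (Bq[k] - b_zp));
  }
  std::cout << "int8 path: " << from_int8 << "  float path: " << ref << "\n";
  return 0;
}
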
diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc
new file mode 100644
index 0000000..a758f55
--- /dev/null
+++ b/bench/PackedRequantizeAcc16Benchmark.cc
@@ -0,0 +1,439 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+enum class BenchmarkType {
+ BARE_BONE, // no row-offset in input packing, and no output processing
+ REQUANTIZATION, // no row-offset in input packing, and requantization
+ ROW_OFFSET_AND_REQUANTIZATION, // row-offset in input packing, and
+ // requantization
+ EVERYTHING, // row-offset in input packing, and requantization + spmdm
+};
+
+void performance_test() {
+ vector<vector<int>> shapes = {
+ // m, n, k
+ {64, 68, 17}, {60, 128, 64}, {25088, 256, 64},
+ {25088, 64, 64}, {25088, 64, 576}, {25088, 64, 256},
+
+ {6272, 512, 256}, {6272, 128, 256}, {6272, 128, 1152},
+ {6272, 512, 128}, {6272, 128, 512},
+
+ {1568, 1024, 512}, {1568, 256, 512}, {1568, 256, 2304},
+ {1568, 1024, 256}, {1568, 256, 1024},
+
+ {392, 2048, 1024}, {392, 512, 1024}, {392, 512, 4608},
+ {392, 2048, 512}, {392, 512, 2048},
+ };
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "M, "
+ << "N, "
+ << "K, "
+ << "Output Processing, "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "Total (ms), "
+ << "GOPS" << endl;
+#else
+ cout << setw(7) << "M, " << setw(7) << "N, " << setw(7) << "K, " << setw(32)
+ << "Output Processing, " << setw(18) << "Type, " << setw(5) << "GOPS"
+ << endl;
+#endif
+
+ chrono::time_point<chrono::high_resolution_clock> begin, end;
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ float alpha = 1.f, beta = 0.f;
+ aligned_vector<float> Afp32(m * k, 0.0f);
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ aligned_vector<float> Bfp32(k * n, 0.0f);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+
+ aligned_vector<float> Cfp32_mkl(m * n, 0.0f);
+ // just used for result comparisons
+    aligned_vector<int32_t> Cint32_mkl(m * n, 0);
+    // requantize results
+    aligned_vector<uint8_t> Cint8_mkl(m * n, 0);
+    aligned_vector<int32_t> Cint32_fb(m * n, 0);
+    aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+
+ // A matrix
+ randFill(Afp32, 0, 50);
+ int32_t Aint8_zero_point = 43;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Aint8[i] = static_cast<uint8_t>(Afp32[i]);
+ }
+
+ randFill(Bfp32, -8, 8);
+
+ double nops = 2.0 * static_cast<double>(NITER) * m * n * k;
+ double ttot = 0.0;
+ string runType;
+
+#ifdef USE_MKL
+ ttot = 0.0;
+ runType = "MKL_fp32";
+ cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k
+ << ", ";
+ cout << setw(30) << "NA";
+ cout << ", ";
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ CblasNoTrans,
+ m,
+ n,
+ k,
+ alpha,
+ Afp32.data(),
+ k,
+ Bfp32.data(),
+ n,
+ beta,
+ Cfp32_mkl.data(),
+ n);
+ end = chrono::high_resolution_clock::now();
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+ }
+ }
+ ((volatile char*)(llc.data()));
+ cout << setw(16) << runType << ", " << fixed << setw(5) << setprecision(1)
+ << nops / ttot << endl;
+
+ for (auto i = 0; i < Cfp32_mkl.size(); ++i) {
+ Cint32_mkl[i] = static_cast<int32_t>(Cfp32_mkl[i]);
+ }
+#endif
+
+ for (BenchmarkType bench_type :
+ {BenchmarkType::BARE_BONE,
+ BenchmarkType::REQUANTIZATION,
+ BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION,
+ BenchmarkType::EVERYTHING}) {
+ // When we don't compute row_offset in fbgemm, we set B_zero_point to 0
+ // to get the same result as the reference.
+ int32_t Bint8_zero_point = (bench_type == BenchmarkType::BARE_BONE ||
+ bench_type == BenchmarkType::REQUANTIZATION)
+ ? 0
+ : -30;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bint8[i] = static_cast<int8_t>(Bfp32[i]);
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data());
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ row_offsets_u8acc32_ref(m, k, k, Aint8.data(), row_offsets.data());
+
+ float C_multiplier =
+ (bench_type == BenchmarkType::BARE_BONE) ? 1 : 0.1234;
+ int32_t C_zero_pt = (bench_type == BenchmarkType::BARE_BONE) ? 0 : 5;
+
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n,
+ // "B unpacked");
+ // packedB.printPackedMatrix("B Packed");
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+
+ cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k
+ << ", ";
+ switch (bench_type) {
+ case BenchmarkType::BARE_BONE:
+ cout << setw(30) << "bare_bone";
+ break;
+ case BenchmarkType::REQUANTIZATION:
+ cout << setw(30) << "requantization";
+ break;
+ case BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION:
+ cout << setw(30) << "row_offset_and_requantization";
+ break;
+ case BenchmarkType::EVERYTHING:
+ cout << setw(30) << "everything";
+ break;
+ };
+ cout << ", ";
+
+ requantize_u8acc32_ref(
+ m,
+ n,
+ n,
+ Cint32_mkl.data(),
+ Cint8_mkl.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ nullptr); // bias
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAMatrix<uint8_t, int16_t> packA(
+ matrix_op_t::NoTranspose,
+ m,
+ k,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point);
+ PackAWithRowOffset<uint8_t, int16_t> packAWithRowOffset(
+ matrix_op_t::NoTranspose,
+ m,
+ k,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ matrix_op_t::NoTranspose, k, n, Bint8.data(), n);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ // spmdm -> requantization -> nothing
+ // construct an output processing pipeline in reverse order
+ // i.e. last output operation first
+      // The last operation should always be DoNothing with the
+      // correct input and output types.
+ DoNothing<> doNothingObj{};
+ // Requantization back to int8
+ ReQuantizeOutput<false> reqObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ bench_type == BenchmarkType::REQUANTIZATION
+ ? nullptr
+ : packAWithRowOffset.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ CompressedSparseColumn B_csc(k, n);
+
+ float density = 0.001f;
+
+ // deterministic random number
+ default_random_engine eng;
+ binomial_distribution<> per_col_nnz_dist(k, density);
+
+ vector<int> row_indices(k);
+
+ int total_nnz = 0;
+ for (int j = 0; j < n; ++j) {
+ B_csc.ColPtr()[j] = total_nnz;
+
+ int nnz_of_j = per_col_nnz_dist(eng);
+ total_nnz += nnz_of_j;
+
+ iota(row_indices.begin(), row_indices.end(), 0);
+ shuffle(row_indices.begin(), row_indices.end(), eng);
+ sort(row_indices.begin(), row_indices.begin() + nnz_of_j);
+
+ for (int kidx = 0; kidx < nnz_of_j; ++kidx) {
+ B_csc.RowIdx().push_back(row_indices[kidx]);
+ // put the current B value
+ B_csc.Values().push_back(Bint8[row_indices[kidx] * n + j]);
+ // make current B value zero
+ Bint8[row_indices[kidx] * n + j] = 0;
+ // std::cout << "(" << row_indices[kidx] << ", " << j << ")" <<
+ // endl;
+ }
+ }
+ B_csc.ColPtr()[n] = total_nnz;
+
+      // the topmost (first) operation in the output processing
+      // pipeline is spmdm
+      // outType = final output type after fully processing through the pipeline
+ // inType = initial input type at the first call to the whole pipeline
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<false>::outType,
+ int32_t,
+ ReQuantizeOutput<false>>
+ spmdmObj(reqObj, Aint8.data(), k, B_csc);
+
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint32_mkl.data(), m, n, n, "C mkl");
+ ttot = 0;
+ runType = "FBGEMM_i8_acc16";
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ switch (bench_type) {
+ case BenchmarkType::BARE_BONE:
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ n,
+ memcopyObj,
+ 0,
+ 1);
+ break;
+ case BenchmarkType::REQUANTIZATION:
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ n,
+ reqObj,
+ 0,
+ 1);
+ break;
+ case BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION:
+ fbgemmPacked(
+ packAWithRowOffset,
+ packedB,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ n,
+ reqObj,
+ 0,
+ 1);
+ break;
+ case BenchmarkType::EVERYTHING:
+ fbgemmPacked(
+ packAWithRowOffset,
+ packedB,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ n,
+ spmdmObj,
+ 0,
+ 1);
+ break;
+ };
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+
+ ((volatile char*)(llc.data()));
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B
+ // unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(),
+ // m, n, n, "C requantized after");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint8_fb.data(), m, n, n, "C fb");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // col_offsets.data(), 1, n, n, "col offsets after");
+ // compare_buffers(row_offsets.data(), row_offset_buf.data(),
+ // row_offsets.size(), 5);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setw(16) << runType << ", " << fixed << setw(5) << setprecision(1)
+ << nops / ttot << endl;
+
+#ifdef USE_MKL
+ if (bench_type == BenchmarkType::BARE_BONE) {
+ compare_buffers(Cint32_mkl.data(), Cint32_fb.data(), m, n, n, 5);
+ } else {
+ compare_buffers(Cint8_mkl.data(), Cint8_fb.data(), m, n, n, 5);
+ }
+#endif
+    } // bench_type
+ cout << endl;
+ } // shapes
+}
+
+int main() {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
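
In the EVERYTHING variant above, a small fraction of B's entries (density 0.001) is pulled out into CompressedSparseColumn storage and zeroed in the dense copy, so the dense GEMM plus the sparse step (spmdm) together reproduce the original product. The CSC layout itself is three arrays: column pointers, row indices, and values, with ColPtr()[j]..ColPtr()[j+1] delimiting the nonzeros of column j. A tiny standalone sketch of that layout, using a hypothetical 3x2 matrix and plain std::vector stand-ins for the B_csc accessors:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Dense 3x2 matrix with two "outlier" nonzeros, row-major:
  //   0  9
  //   0  0
  //   4  0
  const int k = 3, n = 2;
  std::vector<std::int8_t> dense = {0, 9, 0, 0, 4, 0};

  // Compressed sparse column view: col_ptr has n+1 entries and
  // col_ptr[j]..col_ptr[j+1] indexes the nonzeros of column j.
  std::vector<int> col_ptr(n + 1, 0);
  std::vector<int> row_idx;
  std::vector<std::int8_t> values;
  for (int j = 0; j < n; ++j) {
    col_ptr[j] = static_cast<int>(values.size());
    for (int i = 0; i < k; ++i) {
      if (dense[i * n + j] != 0) {
        row_idx.push_back(i);
        values.push_back(dense[i * n + j]);
        dense[i * n + j] = 0;  // the dense copy keeps only the non-outliers
      }
    }
  }
  col_ptr[n] = static_cast<int>(values.size());

  for (int j = 0; j < n; ++j) {
    for (int p = col_ptr[j]; p < col_ptr[j + 1]; ++p) {
      std::cout << "column " << j << ": row " << row_idx[p]
                << " value " << static_cast<int>(values[p]) << "\n";
    }
  }
  return 0;
}
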
diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc
new file mode 100644
index 0000000..27f1433
--- /dev/null
+++ b/bench/PackedRequantizeAcc32Benchmark.cc
@@ -0,0 +1,329 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "test/QuantizationHelpers.h"
+#include "BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+void performance_test() {
+ vector<vector<int>> shapes = {
+ {156800, 4, 36},
+ {156800, 8, 36},
+ {156800, 16, 36},
+ {1, 128, 512},
+ {1, 1024, 256},
+ {1, 2048, 512},
+ {1, 4096, 1024},
+
+ {6, 256, 1024},
+ {6, 256, 2048},
+ {6, 512, 512},
+ {6, 1024, 256},
+ {6, 2048, 256},
+ {6, 2048, 512},
+ {6, 4096, 256},
+ {6, 4096, 1024},
+ {6, 4096, 2048},
+
+ {10, 2048, 256},
+ {10, 4096, 1024},
+
+ {20, 2048, 256},
+ {20, 4096, 1024},
+
+ {102, 1024, 512},
+ {102, 2323, 256},
+ {102, 512, 256},
+
+ {1, 800, 3200},
+ {1, 800, 8000},
+
+ {16, 256, 1500},
+ {16, 256, 1567},
+ {1, 128, 2876},
+ {16, 128, 1567},
+ {1, 128, 2722},
+ {16, 256, 512},
+ };
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "M, "
+ << "N, "
+ << "K, "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "Total (ms), "
+ << "GOPs" << endl;
+#else
+ cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, " << setw(18)
+ << "Type, " << setw(5) << "GOPS" << endl;
+#endif
+
+ chrono::time_point<chrono::high_resolution_clock> start, end;
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ float alpha = 1.f, beta = 0.f;
+ aligned_vector<float> Afp32(m * k, 0.0f);
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ aligned_vector<float> Bfp32(k * n, 0.0f);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+
+ aligned_vector<float> Cfp32_mkl(m * n, 0.0f);
+    aligned_vector<int32_t> Cint32_mkl(m * n, 0);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<int32_t> Cint32_local(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+ aligned_vector<uint8_t> Cint8_local(m * n, 0);
+
+ // A matrix
+ randFill(Aint8, 0, 255);
+ // float Aint8_scale = 0.11;
+ int32_t Aint8_zero_point = 43;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Afp32[i] = (float)Aint8[i];
+ }
+
+ randFill(Bint8, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8.data());
+
+ // float Bint8_scale = 0.49;
+ int32_t Bint8_zero_point = -30;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bfp32[i] = (float)Bint8[i];
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data());
+
+ double nops = 2.0 * static_cast<double>(NITER) * m * n * k;
+ double ttot = 0.0;
+ string runType;
+#ifdef USE_MKL
+ runType = "MKL_fp32";
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+ llc_flush(llc);
+ start = chrono::high_resolution_clock::now();
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ CblasNoTrans,
+ m,
+ n,
+ k,
+ alpha,
+ Afp32.data(),
+ k,
+ Bfp32.data(),
+ n,
+ beta,
+ Cfp32_mkl.data(),
+ n);
+ end = chrono::high_resolution_clock::now();
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - start);
+ ttot += dur.count();
+ }
+ }
+ ((volatile char*)(llc.data()));
+
+ cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k << ", "
+ << setw(16) << runType << ", " << setw(5) << fixed << setw(5)
+ << setprecision(1) << nops / ttot << endl;
+
+ for (auto i = 0; i < Cfp32_mkl.size(); ++i) {
+ Cint32_mkl[i] = (int32_t)Cfp32_mkl[i];
+ }
+#endif
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ float C_multiplier = 0.1234;
+ int32_t C_zero_pt = 5;
+
+ matmul_u8i8acc32_ref(
+ m, n, k, k, n, n, Aint8.data(), Bint8.data(), Cint32_local.data());
+
+ row_offsets_u8acc32_ref(m, k, k, Aint8.data(), row_offsets.data());
+
+ requantize_u8acc32_ref(
+ m,
+ n,
+ n,
+ Cint32_local.data(),
+ Cint8_local.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ nullptr); // bias
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B
+ // unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_local.data(),
+ // m, n, n, "C int32");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint8_local.data(), m, n, n, "C requantized");
+ // printMatrix(matrix_op_t::NoTranspose, col_offsets.data(), 1, n, n, "col
+ // offsets before");
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(PackAWithRowOffset<uint8_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t> packedBN(
+ matrix_op_t::NoTranspose,
+ k,
+ n,
+ Bint8.data(),
+ n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false> outputProcObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ ttot = 0.0;
+ runType = "FBGEMM_i8_acc32";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ start = chrono::high_resolution_clock::now();
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint8_fb.data(),
+ Cint32_buffer.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - start);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+ ((volatile char*)(llc.data()));
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B
+ // unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k,
+ // "A unpacked");
+ // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(),
+ // m, n, n, "C requantized after");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint8_fb.data(), m, n, n, "C fb");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // col_offsets.data(), 1, n, n, "col offsets after");
+ // compare_buffers(row_offsets.data(), row_offset_buf.data(),
+ // row_offsets.size(), 5);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k << ", "
+ << setw(16) << runType << ", " << setw(5) << fixed << setw(5)
+ << setprecision(1) << nops / ttot << endl;
+ cout << endl;
+
+#ifdef USE_MKL
+ compare_buffers(Cint8_local.data(), Cint8_fb.data(), m, n, n, 5);
+#endif
+ }
+}
+
+int main(int /* unused */, char** /* unused */) {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
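
Both requantization benchmarks end by mapping the corrected int32 accumulators down to uint8 through an output multiplier and an output zero point, which is what requantize_u8acc32_ref and ReQuantizeOutput model here. A minimal scalar sketch of that last step, assuming round-to-nearest and saturation to the uint8 range (the zero-point/offset correction that produces the corrected value is the same algebra shown after the float benchmark above):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Requantize a zero-point-corrected int32 accumulator down to uint8:
// scale by the output multiplier, shift by the output zero point, saturate.
std::uint8_t requantize(std::int32_t corrected, float c_multiplier,
                        std::int32_t c_zero_pt) {
  const float scaled = c_multiplier * static_cast<float>(corrected);
  const std::int32_t rounded =
      static_cast<std::int32_t>(std::nearbyint(scaled)) + c_zero_pt;
  return static_cast<std::uint8_t>(std::min(255, std::max(0, rounded)));
}

int main() {
  // Hypothetical accumulators, using the benchmark's multiplier and zero point.
  std::cout << static_cast<int>(requantize(1234, 0.1234f, 5)) << "\n";   // 157
  std::cout << static_cast<int>(requantize(-9999, 0.1234f, 5)) << "\n";  // saturates to 0
  std::cout << static_cast<int>(requantize(99999, 0.1234f, 5)) << "\n";  // saturates to 255
  return 0;
}
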
diff --git a/cmake/modules/DownloadASMJIT.cmake b/cmake/modules/DownloadASMJIT.cmake
new file mode 100644
index 0000000..6e6a6f9
--- /dev/null
+++ b/cmake/modules/DownloadASMJIT.cmake
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+project(asmjit-download NONE)
+
+include(ExternalProject)
+
+ExternalProject_Add(asmjit
+ GIT_REPOSITORY https://github.com/asmjit/asmjit
+ GIT_TAG master
+ SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/asmjit"
+ BINARY_DIR "${FBGEMM_BINARY_DIR}/asmjit"
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+)
diff --git a/cmake/modules/DownloadCPUINFO.cmake b/cmake/modules/DownloadCPUINFO.cmake
new file mode 100644
index 0000000..730ecbf
--- /dev/null
+++ b/cmake/modules/DownloadCPUINFO.cmake
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+project(cpuinfo-download NONE)
+
+include(ExternalProject)
+
+ExternalProject_Add(cpuinfo
+ GIT_REPOSITORY https://github.com/pytorch/cpuinfo
+ GIT_TAG master
+ SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo"
+ BINARY_DIR "${FBGEMM_BINARY_DIR}/cpuinfo"
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+)
diff --git a/cmake/modules/DownloadGTEST.cmake b/cmake/modules/DownloadGTEST.cmake
new file mode 100644
index 0000000..b0a56f8
--- /dev/null
+++ b/cmake/modules/DownloadGTEST.cmake
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+project(googletest-download NONE)
+
+include(ExternalProject)
+
+ExternalProject_Add(googletest
+ GIT_REPOSITORY https://github.com/google/googletest
+ GIT_TAG 0fc5466dbb9e623029b1ada539717d10bd45e99e
+ SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/googletest"
+ BINARY_DIR "${FBGEMM_BINARY_DIR}/googletest"
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+)
diff --git a/cmake/modules/FindMKL.cmake b/cmake/modules/FindMKL.cmake
new file mode 100644
index 0000000..6b38af7
--- /dev/null
+++ b/cmake/modules/FindMKL.cmake
@@ -0,0 +1,270 @@
+# - Find INTEL MKL library
+# This module finds the Intel MKL libraries.
+# Note: This file is a modified version of pytorch/cmake/Modules/FindMKL.cmake
+#
+# This module sets the following variables:
+# MKL_FOUND - set to true if a library implementing the CBLAS interface is found
+# MKL_VERSION - best guess
+# MKL_INCLUDE_DIR - path to include dir.
+# MKL_LIBRARIES - list of libraries for base mkl
+
+# Do nothing if MKL_FOUND was set before!
+IF (NOT MKL_FOUND)
+
+SET(MKL_VERSION)
+SET(MKL_INCLUDE_DIR)
+SET(MKL_LIBRARIES)
+
+# Includes
+INCLUDE(CheckTypeSize)
+INCLUDE(CheckFunctionExists)
+
+# Intel Compiler Suite
+SET(INTEL_COMPILER_DIR "/opt/intel" CACHE STRING
+ "Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)")
+SET(INTEL_MKL_DIR "/opt/intel/mkl" CACHE STRING
+ "Root directory of the Intel MKL (standalone)")
+SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL
+ "Force using the sequential (non threaded) libraries")
+
+# Checks
+CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP)
+IF ("${SIZE_OF_VOIDP}" EQUAL 8)
+ SET(mklvers "intel64")
+ SET(iccvers "intel64")
+ SET(mkl64s "_lp64")
+ELSE ("${SIZE_OF_VOIDP}" EQUAL 8)
+ SET(mklvers "32")
+ SET(iccvers "ia32")
+ SET(mkl64s)
+ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8)
+IF(CMAKE_COMPILER_IS_GNUCC)
+ SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread")
+ SET(mklifaces "intel" "gf")
+ SET(mklrtls "gomp" "iomp5")
+ELSE(CMAKE_COMPILER_IS_GNUCC)
+ SET(mklthreads "mkl_intel_thread")
+ SET(mklifaces "intel")
+ SET(mklrtls "iomp5" "guide")
+ IF (MSVC)
+ SET(mklrtls "libiomp5md")
+ ENDIF (MSVC)
+ENDIF (CMAKE_COMPILER_IS_GNUCC)
+
+# Kernel libraries dynamically loaded
+SET(mklseq)
+
+
+# Paths
+SET(saved_CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH})
+SET(saved_CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH})
+IF(WIN32)
+ # Set default MKLRoot for Windows
+ IF($ENV{MKLProductDir})
+ SET(INTEL_COMPILER_DIR $ENV{MKLProductDir})
+ ELSE()
+ SET(INTEL_COMPILER_DIR
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows")
+ ENDIF()
+ # Change mklvers and iccvers when we are using MSVC instead of ICC
+ IF(MSVC AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+ SET(mklvers "${mklvers}_win")
+ SET(iccvers "${iccvers}_win")
+ ENDIF()
+ENDIF(WIN32)
+IF (EXISTS ${INTEL_COMPILER_DIR})
+ # TODO: diagnostic if dir does not exist
+ SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+ "${INTEL_COMPILER_DIR}/lib/${iccvers}")
+ IF(MSVC)
+ SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+ "${INTEL_COMPILER_DIR}/compiler/lib/${iccvers}")
+ ENDIF()
+ IF (NOT EXISTS ${INTEL_MKL_DIR})
+ SET(INTEL_MKL_DIR "${INTEL_COMPILER_DIR}/mkl")
+ ENDIF()
+ENDIF()
+IF (EXISTS ${INTEL_MKL_DIR})
+ # TODO: diagnostic if dir does not exist
+ SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}
+ "${INTEL_MKL_DIR}/include")
+ SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+ "${INTEL_MKL_DIR}/lib/${mklvers}")
+ IF (MSVC)
+ SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+ "${INTEL_MKL_DIR}/lib/${iccvers}")
+ ENDIF()
+ENDIF()
+
+# Try linking multiple libs
+MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags)
+ # This macro checks for the existence of the combination of libraries given by _list.
+ # If the combination is found, this macro checks whether we can link against that library
+ # combination using the name of a routine given by _name using the linker
+ # flags given by _flags. If the combination of libraries is found and passes
+ # the link test, LIBRARIES is set to the list of complete library paths that
+ # have been found. Otherwise, LIBRARIES is set to FALSE.
+ # N.B. _prefix is the prefix applied to the names of all cached variables that
+ # are generated internally and marked advanced by this macro.
+ SET(_prefix "${LIBRARIES}")
+ # start checking
+ SET(_libraries_work TRUE)
+ SET(${LIBRARIES})
+ SET(_combined_name)
+ SET(_paths)
+ SET(__list)
+ FOREACH(_elem ${_list})
+ IF(__list)
+ SET(__list "${__list} - ${_elem}")
+ ELSE(__list)
+ SET(__list "${_elem}")
+ ENDIF(__list)
+ ENDFOREACH(_elem)
+ MESSAGE(STATUS "Checking for [${__list}]")
+ FOREACH(_library ${_list})
+ SET(_combined_name ${_combined_name}_${_library})
+ IF(_libraries_work)
+ IF(${_library} STREQUAL "gomp")
+ FIND_PACKAGE(OpenMP)
+ IF(OPENMP_FOUND)
+ SET(${_prefix}_${_library}_LIBRARY ${OpenMP_C_FLAGS})
+ ENDIF(OPENMP_FOUND)
+ ELSE(${_library} STREQUAL "gomp")
+ FIND_LIBRARY(${_prefix}_${_library}_LIBRARY NAMES ${_library})
+ ENDIF(${_library} STREQUAL "gomp")
+ MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY)
+ SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
+ SET(_libraries_work ${${_prefix}_${_library}_LIBRARY})
+ IF(${_prefix}_${_library}_LIBRARY)
+ MESSAGE(STATUS " Library ${_library}: ${${_prefix}_${_library}_LIBRARY}")
+ ELSE(${_prefix}_${_library}_LIBRARY)
+ MESSAGE(STATUS " Library ${_library}: not found")
+ ENDIF(${_prefix}_${_library}_LIBRARY)
+ ENDIF(_libraries_work)
+ ENDFOREACH(_library ${_list})
+ # Test this combination of libraries.
+ IF(_libraries_work)
+ SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}})
+ SET(CMAKE_REQUIRED_LIBRARIES "${CMAKE_REQUIRED_LIBRARIES};${CMAKE_REQUIRED_LIBRARIES}")
+ CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS)
+ SET(CMAKE_REQUIRED_LIBRARIES)
+ MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS)
+ SET(_libraries_work ${${_prefix}${_combined_name}_WORKS})
+ ENDIF(_libraries_work)
+ # Fin
+ IF(_libraries_work)
+ ELSE (_libraries_work)
+ SET(${LIBRARIES})
+ MARK_AS_ADVANCED(${LIBRARIES})
+ ENDIF(_libraries_work)
+ENDMACRO(CHECK_ALL_LIBRARIES)
+
+IF(WIN32)
+ SET(mkl_m "")
+ SET(mkl_pthread "")
+ELSE(WIN32)
+ SET(mkl_m "m")
+ SET(mkl_pthread "pthread")
+ENDIF(WIN32)
+
+IF(UNIX AND NOT APPLE)
+ SET(mkl_dl "${CMAKE_DL_LIBS}")
+ELSE(UNIX AND NOT APPLE)
+ SET(mkl_dl "")
+ENDIF(UNIX AND NOT APPLE)
+
+# Check for version 10/11
+IF (NOT MKL_LIBRARIES)
+ SET(MKL_VERSION 1011)
+ENDIF (NOT MKL_LIBRARIES)
+FOREACH(mklrtl ${mklrtls} "")
+ FOREACH(mkliface ${mklifaces})
+ FOREACH(mkl64 ${mkl64s} "")
+ FOREACH(mklthread ${mklthreads})
+ IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
+ CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
+ "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};${mkl_pthread};${mkl_m};${mkl_dl}" "")
+ ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
+ ENDFOREACH(mklthread)
+ ENDFOREACH(mkl64)
+ ENDFOREACH(mkliface)
+ENDFOREACH(mklrtl)
+FOREACH(mklrtl ${mklrtls} "")
+ FOREACH(mkliface ${mklifaces})
+ FOREACH(mkl64 ${mkl64s} "")
+ IF (NOT MKL_LIBRARIES)
+ CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
+ "mkl_${mkliface}${mkl64};mkl_sequential;mkl_core;${mkl_m};${mkl_dl}" "")
+ IF (MKL_LIBRARIES)
+ SET(mklseq "_sequential")
+ ENDIF (MKL_LIBRARIES)
+ ENDIF (NOT MKL_LIBRARIES)
+ ENDFOREACH(mkl64)
+ ENDFOREACH(mkliface)
+ENDFOREACH(mklrtl)
+FOREACH(mklrtl ${mklrtls} "")
+ FOREACH(mkliface ${mklifaces})
+ FOREACH(mkl64 ${mkl64s} "")
+ FOREACH(mklthread ${mklthreads})
+ IF (NOT MKL_LIBRARIES)
+ CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
+ "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
+ ENDIF (NOT MKL_LIBRARIES)
+ ENDFOREACH(mklthread)
+ ENDFOREACH(mkl64)
+ ENDFOREACH(mkliface)
+ENDFOREACH(mklrtl)
+
+# Check for older versions
+IF (NOT MKL_LIBRARIES)
+ SET(MKL_VERSION 900)
+ CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
+ "mkl;guide;pthread;m" "")
+ENDIF (NOT MKL_LIBRARIES)
+
+# Include files
+IF (MKL_LIBRARIES)
+ FIND_PATH(MKL_INCLUDE_DIR "mkl_cblas.h")
+ MARK_AS_ADVANCED(MKL_INCLUDE_DIR)
+ENDIF (MKL_LIBRARIES)
+
+# LibIRC: the Intel compiler always links this;
+# gcc does not, but MKL kernels sometimes need it.
+IF (MKL_LIBRARIES)
+ IF (CMAKE_COMPILER_IS_GNUCC)
+ FIND_LIBRARY(MKL_KERNEL_libirc "irc")
+ ELSEIF (CMAKE_C_COMPILER_ID AND NOT CMAKE_C_COMPILER_ID STREQUAL "Intel")
+ FIND_LIBRARY(MKL_KERNEL_libirc "irc")
+ ENDIF (CMAKE_COMPILER_IS_GNUCC)
+ MARK_AS_ADVANCED(MKL_KERNEL_libirc)
+ IF (MKL_KERNEL_libirc)
+ SET(MKL_LIBRARIES ${MKL_LIBRARIES} ${MKL_KERNEL_libirc})
+ ENDIF (MKL_KERNEL_libirc)
+ENDIF (MKL_LIBRARIES)
+
+# Final
+SET(CMAKE_LIBRARY_PATH ${saved_CMAKE_LIBRARY_PATH})
+SET(CMAKE_INCLUDE_PATH ${saved_CMAKE_INCLUDE_PATH})
+IF (MKL_LIBRARIES AND MKL_INCLUDE_DIR)
+ SET(MKL_FOUND TRUE)
+ set(MKL_cmake_included true)
+ELSE (MKL_LIBRARIES AND MKL_INCLUDE_DIR)
+ SET(MKL_FOUND FALSE)
+ SET(MKL_VERSION)
+ENDIF (MKL_LIBRARIES AND MKL_INCLUDE_DIR)
+
+# Standard termination
+IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
+ MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
+ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
+IF(NOT MKL_FIND_QUIETLY)
+ IF(MKL_FOUND)
+ MESSAGE(STATUS "MKL library found")
+ ELSE(MKL_FOUND)
+ MESSAGE(STATUS "MKL library not found")
+ return()
+ ENDIF(MKL_FOUND)
+ENDIF(NOT MKL_FIND_QUIETLY)
+
+# Do nothing if MKL_FOUND was set before!
+ENDIF (NOT MKL_FOUND)
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
new file mode 100644
index 0000000..02e862f
--- /dev/null
+++ b/include/fbgemm/ConvUtils.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <string>
+
+namespace fbgemm2 {
+
+/**
+ * @brief A struct to conveniently store all convolution parameters.
+ */
+struct conv_param_t {
+ int MB; ///< Mini Batch size
+ int IC; ///< Number of Input Channels
+ int OC; ///< Number of Output Channels
+ int IH; ///< Input Image Height
+ int IW; ///< Input Image Width
+ int G; ///< Number of Groups
+ int KH; ///< Filter (Kernel) Height
+ int KW; ///< Filter (Kernel) Width
+ int stride_h; ///< Stride in Height Dimension
+ int stride_w; ///< Stride in Width Dimension
+ int pad_h; ///< Padding in Height Dimension (top and bottom)
+ int pad_w; ///< Padding in Width Dimension (left and right)
+ int dilation_h; ///< Kernel dilation in Height Dimension
+ int dilation_w; ///< Kernel dilation in Width Dimension
+
+ // The following are derived parameters
+ int OH; ///< Output Image Height
+ int OW; ///< Output Image Width
+ int IHP; ///< Input Height Padded
+ int IWP; ///< Input Width Padded
+
+ /**
+ * @brief Constructor for initializing the convolution parameters.
+ * TODO: Dilation is not handled correctly.
+ */
+ conv_param_t(
+ int mb,
+ int ic,
+ int oc,
+ int ih,
+ int iw,
+ int g = 1,
+ int kh = 3,
+ int kw = 3,
+ int strd_h = 1,
+ int strd_w = 1,
+ int pd_h = 1,
+ int pd_w = 1)
+ : MB(mb),
+ IC(ic),
+ OC(oc),
+ IH(ih),
+ IW(iw),
+ G(g),
+ KH(kh),
+ KW(kw),
+ stride_h(strd_h),
+ stride_w(strd_w),
+ pad_h(pd_h),
+ pad_w(pd_w),
+ dilation_h(1),
+ dilation_w(1) {
+ IHP = IH + 2 * pad_h;
+ IWP = IW + 2 * pad_w;
+ OH = (IHP - KH) / stride_h + 1;
+ OW = (IWP - KW) / stride_w + 1;
+ }
+
+ /**
+ * @brief Helper function to get convolution parameters as string.
+ */
+ std::string toString() const {
+ std::string out = "";
+ out += "MB:" + std::to_string(MB) + ", ";
+ out += "IC:" + std::to_string(IC) + ", ";
+ out += "OC:" + std::to_string(OC) + ", ";
+ out += "IH:" + std::to_string(IH) + ", ";
+ out += "IW:" + std::to_string(IW) + ", ";
+ out += "G:" + std::to_string(G) + ", ";
+ out += "KH:" + std::to_string(KH) + ", ";
+ out += "KW:" + std::to_string(KW) + ", ";
+ out += "stride_h:" + std::to_string(stride_h) + ", ";
+ out += "stride_w:" + std::to_string(stride_w) + ", ";
+ out += "pad_h:" + std::to_string(pad_h) + ", ";
+ out += "pad_w:" + std::to_string(pad_w);
+ return out;
+ }
+};
+
+} // namespace fbgemm2
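
The derived fields OH and OW follow the usual convolution output-size formula, computed in the constructor as (I + 2*pad - K) / stride + 1 (with dilation fixed at 1, per the TODO). A quick standalone check of that formula for two of the benchmark shapes:

#include <iostream>

// Output spatial size for a convolution, matching the derivation in
// conv_param_t's constructor (dilation fixed at 1 there).
int conv_out_dim(int in, int kernel, int stride, int pad) {
  return (in + 2 * pad - kernel) / stride + 1;
}

int main() {
  // 14x14 input, 3x3 kernel, stride 1, pad 1 -> "same" 14x14 output.
  std::cout << conv_out_dim(14, 3, 1, 1) << "x" << conv_out_dim(14, 3, 1, 1)
            << "\n";
  // 14x14 input, 3x3 kernel, stride 2, pad 1 -> 7x7 output.
  std::cout << conv_out_dim(14, 3, 2, 1) << "x" << conv_out_dim(14, 3, 2, 1)
            << "\n";
  return 0;
}
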
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
new file mode 100644
index 0000000..988e24b
--- /dev/null
+++ b/include/fbgemm/Fbgemm.h
@@ -0,0 +1,952 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+/**
+ * Top level include file for FBGEMM.
+ */
+#include <immintrin.h>
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include "ConvUtils.h"
+#include "FbgemmI8Spmdm.h"
+#include "Types.h"
+#include "Utils.h"
+
+// #define FBGEMM_MEASURE_TIME_BREAKDOWN
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+#include <chrono>
+#include <iostream>
+extern double packing_time;
+extern double computing_time;
+extern double kernel_time;
+extern double postprocessing_time;
+extern double run_time;
+#endif
+
+namespace fbgemm2 {
+
+/**
+ * @brief Templatized struct for packing parameters for A and B matrices.
+ *
+ * @tparam T input type
+ * @tparam accT the type used for accumulation
+ * @tparam instSet anyarch/avx2/avx512
+ * @tparam int8Type an auxiliary template parameter to specialize for 8-bit
+ * input types.
+ */
+template <
+ typename T,
+ typename accT,
+ inst_set_t instSet,
+ typename int8Type = void>
+struct PackingTraits;
+
+// type specialized implementation in an include file
+#include "PackingTraits-inl.h"
+
+/**
+ * @brief Base class for packing matrices for higher GEMM performance.
+ *
+ * The matrix is tiled into blockRows() * blockCols() blocks.
+ * Each block has size blockRowSize() * blockColSize().
+ * This class is designed using CRTP
+ * (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
+ *
+ * @tparam PT actual packing type, e.g., PackAWithRowOffset
+ */
+template <typename PT, typename inpType, typename accType = std::int32_t>
+class PackMatrix {
+ public:
+ PackMatrix() = delete; // no default constructor
+
+ /**
+ * @param rows total number of rows in the matrix
+ * (packed rows can be less than rows).
+ * @param cols total number of columns in the matrix
+ * @param pmat A buffer to contain the packed matrix.
+ * If nullptr, a buffer owned by PackMatrix will be allocated
+ * internally to contain the packed matrix.
+   *                For non-constant matrices like activation matrices, the client
+   *                code may want to pass a pre-allocated pmat to avoid the
+   *                overhead of internal memory allocation every time a PackMatrix
+   *                is constructed. The client code can query how big pmat should
+   *                be with the packedBufferSize() function.
+ * @param zero_pt the quantized value that maps to 0.0f floating-point number.
+ */
+ PackMatrix(
+ std::int32_t rows,
+ std::int32_t cols,
+ inpType* pmat,
+ std::int32_t zero_pt);
+
+ /**
+   * @return true usually when the matrix is a constant matrix (e.g., a weight
+   * matrix) that can be prepacked
+ */
+ bool isPrePacked() const {
+ return static_cast<const PT*>(this)->isPrePacked();
+ }
+
+ /**
+ * @return true if this is the first input matrix in GEMM (i.e., A in C = A *
+ * B)
+ */
+ static constexpr bool isA() {
+ return PT::isA();
+ }
+
+ /**
+   * @brief The size of the buffer used for packing (the size is in number of
+   * elements).
+   *
+   * rows and cols are only used for full packing, i.e., for the B matrix. The
+   * client code can use this function to query how big the buffer used for
+   * packing should be.
+ */
+ static int packedBufferSize(int rows = 0, int cols = 0);
+
+ /**
+   * @return Pointer to a buffer containing row offset results. Some packing
+   * objects fuse row offset computation for the later requantization step.
+ */
+ std::int32_t* getRowOffsetBuffer() const {
+ return static_cast<const PT*>(this)->getRowOffsetBuffer();
+ }
+
+ /**
+   * @brief When the k loop is also tiled/blocked, this function is used to
+   * check whether we have executed computations for the last k block so that
+   * we can perform post-GEMM operations.
+ */
+ bool isThisLastKBlock(int block_id) const {
+ return static_cast<const PT*>(this)->isThisLastKBlock(block_id);
+ }
+
+ /**
+ * @brief Actual packing of a block of the source matrix in pmat buffer.
+ */
+ void pack(const block_type_t& block) {
+ static_cast<PT*>(this)->pack(block);
+ }
+
+ std::int32_t numRows() const {
+ return nrows_;
+ }
+
+ std::int32_t numCols() const {
+ return ncols_;
+ }
+
+ /**
+ * @return The number of rows in each block
+ */
+ std::int32_t blockRowSize() const {
+ return brow_;
+ }
+
+ /**
+ * @return The number of columns in each block
+ */
+ std::int32_t blockColSize() const {
+ return bcol_;
+ }
+
+ /**
+ * @return The number of blocks along rows
+ */
+ std::int32_t blockRows() const {
+ return nbrow_;
+ }
+
+ /**
+ * @return The number of blocks along columns
+ */
+ std::int32_t blockCols() const {
+ return nbcol_;
+ }
+
+ /**
+   * @return The number of rows in the currently packed block of a matrix.
+   * For a pre-packed (i.e., fully-packed) matrix, it equals the total number
+   * of rows.
+ */
+ std::int32_t numPackedRows() const {
+ return packedBlock_.row_size;
+ }
+
+ /**
+   * @return The number of columns in the currently packed block of a matrix.
+   * For a pre-packed (i.e., fully-packed) matrix, it equals the total number
+   * of columns.
+ */
+ std::int32_t numPackedCols() const {
+ return packedBlock_.col_size;
+ }
+
+ /**
+ * @return The first row of the block we're working on.
+ */
+ std::int32_t packedRowStart() const {
+ return packedBlock_.row_start;
+ }
+
+ /**
+   * @return The beginning of the (rowBlockNum, colBlockNum)th block
+ */
+ inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) {
+ return buf_ + blockRowSize() * blockColSize() * rowBlockNum +
+ blockRowSize() * blockColSize() * blockCols() * colBlockNum;
+ }
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name) {
+ static_cast<PT*>(this)->printPackedMatrix(name);
+ }
+
+ /**
+ * @return The number of rows in the last row block.
+ */
+ std::int32_t lastBrow() const {
+ return last_brow_;
+ }
+
+ /**
+ * @return The number of columns in the last column block.
+ */
+ std::int32_t lastBcol() const {
+ return last_bcol_;
+ }
+
+ /**
+ * @return True if the last column block has fewer columns than the block
+ * size.
+ */
+ bool isThereColRemainder() const {
+ return last_bcol_ != blockColSize();
+ }
+
+ ~PackMatrix() {
+ if (bufAllocatedHere_) {
+ free(buf_);
+ }
+ }
+
+ protected:
+ /**
+ * Set which block we're packing
+ */
+ void packedBlock(const block_type_t& block) {
+ packedBlock_ = block;
+ nbrow_ = (numPackedRows() + blockRowSize() - 1) / blockRowSize();
+ nbcol_ = (numPackedCols() + blockColSize() - 1) / blockColSize();
+
+ last_brow_ = ((numPackedRows() % blockRowSize()) == 0)
+ ? blockRowSize()
+ : (numPackedRows() % blockRowSize());
+ last_bcol_ = ((numPackedCols() % blockColSize()) == 0)
+ ? blockColSize()
+ : (numPackedCols() % blockColSize());
+ }
+
+ /**
+ * @return the quantized value that maps to 0.0f floating-point number
+ */
+ std::int32_t zeroPoint() const {
+ return zero_pt_;
+ }
+
+ inpType* buf_;
+ std::int32_t brow_; ///< the number of rows in each block
+ std::int32_t bcol_; ///< the number of columns in each block
+ std::int32_t nbrow_; ///< the number of blocks along rows
+ std::int32_t nbcol_; ///< the number of blocks along columns
+ bool bufAllocatedHere_;
+
+ private:
+ std::int32_t nrows_, ncols_;
+ std::int32_t zero_pt_;
+ block_type_t packedBlock_; ///< The block in the source matrix just packed
+ std::int32_t last_brow_, last_bcol_;
+};
+
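
PackMatrix's doc comment above points at CRTP: the base class never uses virtual dispatch; it casts this to the concrete packer type and calls the derived method directly, so the per-block pack loop resolves at compile time. A compact, self-contained illustration of that pattern (the class and function names here are hypothetical and only mirror the shape of the real hierarchy):

#include <iostream>

// The base resolves calls to the concrete packer at compile time, mirroring
// how PackMatrix forwards pack()/isPrePacked() to its derived type.
template <typename PT>
class PackBase {
 public:
  void pack() { static_cast<PT*>(this)->pack(); }
  bool isPrePacked() const {
    return static_cast<const PT*>(this)->isPrePacked();
  }
};

class PackActivation : public PackBase<PackActivation> {
 public:
  void pack() { std::cout << "packing an activation block per call\n"; }
  bool isPrePacked() const { return false; }
};

class PackWeight : public PackBase<PackWeight> {
 public:
  void pack() { std::cout << "packing a constant weight block up front\n"; }
  bool isPrePacked() const { return true; }
};

// Generic code written against the base still lands in the derived methods,
// with no virtual calls involved.
template <typename PT>
void runOnce(PackBase<PT>& packer) {
  packer.pack();
  std::cout << "pre-packed: " << packer.isPrePacked() << "\n";
}

int main() {
  PackActivation a;
  PackWeight w;
  runOnce(a);
  runOnce(w);
  return 0;
}
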
+/**
+ * @brief Matrix packed for the first input matrix in GEMM (usually
+ * activation). The source matrix is already quantized. Default
+ * accumulation type is int32.
+ */
+template <typename T, typename accT = std::int32_t>
+class PackAMatrix : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
+ public:
+ using This = PackAMatrix<T, accT>;
+ using BaseType = PackMatrix<This, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackAMatrix() = delete; // no default constructor
+
+ /**
+ * TODO: currently only groups == 1 supported
+ */
+ PackAMatrix(
+ matrix_op_t trans,
+ std::int32_t nRow,
+ std::int32_t nCol,
+ const inpType* smat,
+ std::int32_t ld,
+ inpType* pmat = nullptr,
+ std::int32_t groups = 1,
+ accT zero_pt = 0);
+
+ /**
+   * Activation matrices are not constant, so the cost of pre-packing cannot
+   * be amortized.
+ */
+ bool isPrePacked() const {
+ return false;
+ }
+
+ /**
+   * @return True if this is used as the A matrix.
+ */
+ static constexpr bool isA() {
+ return true;
+ }
+
+ /**
+   * @return A pointer to the row offset buffer. There are no row offset buffer
+   * calculations with this packing class; hence, it returns nullptr.
+ */
+ std::int32_t* getRowOffsetBuffer() const {
+ return nullptr;
+ }
+
+ /**
+ * @return Offset of the element in the packed matrix that was at (i, j) in
+ * the source matrix.
+ */
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack(const block_type_t& block);
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name);
+
+ private:
+ matrix_op_t trans_;
+ const T* smat_;
+ std::int32_t ld_;
+ std::int32_t G_;
+ std::int32_t row_interleave_B_;
+};
+
+/**
+ * @brief Matrix packed for the second input matrix in GEMM (usually weight).
+ * The source matrix is already quantized. Default accumulation
+ * type is int32.
+ */
+template <typename T, typename accT = std::int32_t>
+class PackBMatrix : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
+ public:
+ using This = PackBMatrix<T, accT>;
+ using BaseType = PackMatrix<This, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackBMatrix() = delete; // no default constructor
+
+ /**
+ * TODO: Currently only groups == 1 supported.
+ */
+ PackBMatrix(
+ matrix_op_t trans,
+ std::int32_t nRow,
+ std::int32_t nCol,
+ const inpType* smat,
+ std::int32_t ld,
+ inpType* pmat = nullptr,
+ std::int32_t groups = 1,
+ accT zero_pt = 0);
+
+ /**
+   * Weight matrices are usually constant, so they are worth pre-packing.
+ */
+ bool isPrePacked() const {
+ return true;
+ }
+
+ /**
+   * @return True if this is used as the A matrix, false otherwise.
+ */
+ static constexpr bool isA() {
+ return false;
+ }
+
+ /**
+   * @brief When the k loop is also tiled/blocked, this function is used to
+   *        check whether we have executed computations for the last k block so
+   *        that we can perform post-GEMM operations.
+ */
+ bool isThisLastKBlock(int block_id) const {
+ return (BaseType::blockRows() - 1) == block_id;
+ }
+
+ /**
+ * @return Offset of the element in the packed matrix that was at (i, j) in
+ * the source matrix.
+ */
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack(const block_type_t& block);
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name);
+
+ ~PackBMatrix() {}
+
+ private:
+ matrix_op_t trans_;
+ const T* smat_;
+ std::int32_t ld_;
+ std::int32_t G_;
+ std::int32_t row_interleave_;
+};
+
+/**
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
+ *        and row offsets used for requantization are computed during packing.
+ * Im2col is fused with packing here. The source matrix is already
+ * quantized.
+ */
+template <typename T, typename accT = std::int32_t>
+class PackAWithIm2Col : public PackMatrix<PackAWithIm2Col<T, accT>, T, accT> {
+ public:
+ using This = PackAWithIm2Col<T, accT>;
+ using BaseType = PackMatrix<This, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackAWithIm2Col() = delete; // no default constructor
+ /**
+ * TODO: Currently only groups == 1 supported
+ */
+ PackAWithIm2Col(
+ const conv_param_t& conv_param,
+ const T* sdata,
+ inpType* pmat = nullptr,
+ std::int32_t zero_pt = 0,
+ std::int32_t* row_offset = nullptr);
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack(const block_type_t& block);
+
+ /**
+ * @return A pointer to the row offset buffer.
+ */
+ std::int32_t* getRowOffsetBuffer() const {
+ return row_offset_;
+ }
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name);
+
+ /**
+ * @return Size of row offset buffer in number of elements
+ */
+ static int rowOffsetBufferSize();
+
+ ~PackAWithIm2Col() {
+ if (rowOffsetAllocatedHere) {
+ free(row_offset_);
+ }
+ }
+
+ private:
+ const conv_param_t& conv_p_;
+ const T* sdata_;
+ std::int32_t* row_offset_;
+ bool rowOffsetAllocatedHere;
+ std::int32_t row_interleave_B_;
+};
+
+/**
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
+ *        and row offsets used for requantization are computed during packing.
+ * The source matrix is already quantized.
+ */
+template <typename T, typename accT = std::int32_t>
+class PackAWithRowOffset
+ : public PackMatrix<PackAWithRowOffset<T, accT>, T, accT> {
+ public:
+ using This = PackAWithRowOffset<T, accT>;
+ using BaseType = PackMatrix<This, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackAWithRowOffset() = delete; // no default constructor
+ /**
+ * TODO: Currently only groups == 1 supported
+ */
+ PackAWithRowOffset(
+ matrix_op_t trans,
+ std::uint32_t nRow,
+ std::uint32_t nCol,
+ const T* smat,
+ std::uint32_t ld,
+ inpType* pmat = nullptr,
+ std::uint32_t groups = 1,
+ std::int32_t zero_pt = 0,
+ std::int32_t* row_offset = nullptr);
+
+ /**
+ * @return Offset of the element in the packed matrix that was at (i, j) in
+ * the source matrix
+ */
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack(const block_type_t& block);
+
+ /**
+ * @return A pointer to the row offset buffer.
+ */
+ std::int32_t* getRowOffsetBuffer() const {
+ return row_offset_;
+ }
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name);
+
+ /**
+ * @return size of row offset buffer in number of elements
+ */
+ static int rowOffsetBufferSize();
+
+ ~PackAWithRowOffset() {
+ if (rowOffsetAllocatedHere) {
+ free(row_offset_);
+ }
+ }
+
+ private:
+ matrix_op_t trans_;
+ const T* smat_;
+ std::uint32_t ld_;
+ std::uint32_t G_;
+ std::int32_t* row_offset_;
+ bool rowOffsetAllocatedHere;
+ std::int32_t row_interleave_B_;
+};
+
+/**
+ * @brief Matrix packed for the first input matrix in GEMM (usually activation),
+ *        and row offsets used for requantization are computed during packing.
+ * The source matrix is in fp32 and quantized during packing.
+ */
+template <typename T, typename accT = std::int32_t>
+class PackAWithQuantRowOffset
+ : public PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT> {
+ public:
+ using This = PackAWithQuantRowOffset<T, accT>;
+ using BaseType = PackMatrix<This, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackAWithQuantRowOffset() = delete; // no default constructor
+ /**
+ * TODO: Currently only groups == 1 supported
+ */
+ PackAWithQuantRowOffset(
+ matrix_op_t trans,
+ std::int32_t nRow,
+ std::int32_t nCol,
+ const float* smat,
+ std::int32_t ld,
+ inpType* pmat = nullptr,
+ float scale = 1.0f,
+ std::int32_t zero_pt = 0,
+ std::int32_t groups = 1,
+ std::int32_t* row_offset = nullptr);
+
+ /**
+ * @return offset of the element in the packed matrix that was at (i, j) in
+ * the source matrix
+ */
+ std::int32_t addr(std::int32_t i, std::int32_t j) const;
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack(const block_type_t& block);
+
+ /**
+ * @return A pointer to the row offset buffer.
+ */
+ std::int32_t* getRowOffsetBuffer() const {
+ return row_offset_;
+ }
+
+ /**
+ * @brief Print the packed block.
+ */
+ void printPackedMatrix(std::string name);
+
+ /**
+ * @return Size of row offset buffer in number of elements
+ */
+ static int rowOffsetBufferSize();
+
+ ~PackAWithQuantRowOffset() {
+ if (rowOffsetAllocatedHere) {
+ free(row_offset_);
+ }
+ }
+
+ private:
+ matrix_op_t trans_;
+ const float* smat_;
+ std::int32_t ld_;
+ float scale_;
+ std::int32_t G_;
+ std::int32_t* row_offset_;
+ bool rowOffsetAllocatedHere;
+ std::int32_t row_interleave_B_;
+};
+
+/*
+ *
+ * Post Processing of outputs
+ *
+ */
+
+/**
+ * @brief Does nothing. NoOp. Used as the last operation in the output
+ * processing pipeline.
+ *
+ */
+template <typename outT = std::uint8_t, typename inT = std::uint8_t>
+class DoNothing {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ DoNothing() {}
+ template <inst_set_t instSet>
+ int f(
+ outType* /* unused */,
+ inpType* /* unused */,
+ const block_type_t& /* unused */,
+ int /* unused */,
+ int /* unused */) const {
+ return 0;
+ }
+};
+
+/**
+ * @brief Copy data pointed to by the inp pointer to the out pointer when
+ *        the inp and out pointers are not the same.
+ *        inp buffer: row and column start points: (0, 0)
+ *        output buffer: row and column start points:
+ *        (block.row_start, block.col_start)
+ *
+ * This is the output processing stage that should be passed when there is no
+ * requantization and the output is required in the same format as the internal
+ * buffer used for accumulation.
+ */
+template <
+ typename outT = std::int32_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<outT, outT>>
+class memCopy {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ explicit memCopy(nextOPType& nextop) : nextop_(nextop) {}
+ template <inst_set_t instSet>
+ inline int f(
+ outType* out,
+ inpType* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ nextOPType& nextop_;
+};
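+
+// How output-processing ops compose (editor's illustration, not part of the
+// original header): each op forwards its result to the nextOPType instance it
+// was constructed with. For example:
+//
+//   DoNothing<std::int32_t, std::int32_t> doNothing;
+//   memCopy<> copyToC(doNothing); // copy the int32 accumulator into C, then stop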
+
+/**
+ * @brief Perform scaling on accumulated data.
+ */
+template <
+ typename outT = std::int32_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<outT, outT>>
+class ScaleOP {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ explicit ScaleOP(inpType scalingFactor) : scalingFactor_(scalingFactor) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outType* out,
+ inpType* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ inpType scalingFactor_;
+};
+
+/**
+ * @brief Perform Relu on accumulated data.
+ */
+template <
+ typename outT = std::int32_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<outT, outT>>
+class ReluOutput {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ explicit ReluOutput(inpType zero_pt) : zero_pt_(zero_pt) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outType* out,
+ inpType* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ inpType zero_pt_;
+};
+
+/**
+ * @brief Perform Sparse-Matrix * Dense-Matrix as a part of the output
+ *        processing pipeline.
+ *
+ * SPMDM (SParse Matrix times Dense Matrix) is performed in place on the 32-bit
+ * input buffer (inp). After modifying the input buffer, it is passed to the
+ * next op.
+ */
+template <
+ typename outT = std::int32_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<inT, inT>>
+class DoSpmdmOnInpBuffer {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ DoSpmdmOnInpBuffer(
+ nextOPType& nextop,
+ const std::uint8_t* A,
+ int lda,
+ const CompressedSparseColumn& B_csc)
+ : nextop_(nextop), A_(A), lda_(lda), B_csc_(B_csc) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ nextOPType& nextop_;
+ const std::uint8_t* A_;
+ const int lda_;
+ const CompressedSparseColumn& B_csc_;
+};
+
+/**
+ * @brief Requantize values in the inp buffer, write the results to the out
+ *        buffer, and pass the out buffer to the next op for further processing.
+ *
+ * For each element x, the output is
+ * clamp(round((x - Aq_zero_point * col_offset - row_offset * Bq_zero_point
+ * + bias) * C_multiplier) + C_zero_point, 0, 255),
+ * with the lower bound raised to C_zero_point when FUSE_RELU is true.
+ *
+ */
+template <
+ bool FUSE_RELU,
+ typename outT = std::uint8_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<outT, outT>>
+class ReQuantizeOutput {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ ReQuantizeOutput(
+ nextOPType& nextop,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t Aq_zero_point,
+ std::int32_t Bq_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias)
+ : nextop_(nextop),
+ C_multiplier_(C_multiplier),
+ C_zero_point_(C_zero_point),
+ Aq_zero_point_(Aq_zero_point),
+ Bq_zero_point_(Bq_zero_point),
+ q_row_offsets_(row_offsets),
+ q_col_offsets_(col_offsets),
+ bias_(bias) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ nextOPType& nextop_;
+ float C_multiplier_;
+ std::int32_t C_zero_point_;
+ std::int32_t Aq_zero_point_;
+ std::int32_t Bq_zero_point_;
+ const std::int32_t* q_row_offsets_;
+ const std::int32_t* q_col_offsets_;
+ const std::int32_t* bias_;
+};
+
+/**
+ * @brief Requantize the accumulated int32 data back to float, i.e., the
+ *        output is produced and consumed as float.
+ */
+template <
+ bool FUSE_RELU,
+ typename outT = float,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<outT, outT>>
+class ReQuantizeForFloat {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ ReQuantizeForFloat(
+ nextOPType& nextop,
+ float Aq_scale,
+ float Bq_scale,
+ std::int32_t Aq_zero_point,
+ std::int32_t Bq_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const float* bias)
+ : nextop_(nextop),
+ Aq_scale_(Aq_scale),
+ Bq_scale_(Bq_scale),
+ Aq_zero_point_(Aq_zero_point),
+ Bq_zero_point_(Bq_zero_point),
+ q_row_offsets_(row_offsets),
+ q_col_offsets_(col_offsets),
+ bias_(bias) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ nextOPType& nextop_;
+ float Aq_scale_, Bq_scale_;
+ std::int32_t Aq_zero_point_;
+ std::int32_t Bq_zero_point_;
+ const std::int32_t* q_row_offsets_;
+ const std::int32_t* q_col_offsets_;
+ const float* bias_;
+};
+
+// type specialized implementation in an include file
+#ifdef __AVX2__
+#include "OutputProcessing-inl.h"
+#endif
+
+/*
+ *
+ * ####### GEMM related functions #######
+ *
+ */
+
+/**
+ * Matrix B must be prepacked. For matrix A, the packA.pack function is called
+ * to pack it.
+ *
+ * @tparam packingAMatrix processing of A matrix while packing,
+ * e.g., PackAWithQuantRowOffset
+ *
+ * @tparam packingBMatrix processing of B matrix while packing,
+ * e.g., pre-multiply by alpha
+ * @tparam cT data type of C matrix
+ * @tparam processOutputType further processing of outputs, e.g., Relu
+ */
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename cT,
+ typename processOutputType>
+void fbgemmPacked(
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>& packA,
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packB,
+ cT* C,
+ std::int32_t* C_buffer,
+ std::uint32_t ldc,
+ const processOutputType& outProcess,
+ int thread_id,
+ int num_threads);
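+
+// A minimal end-to-end sketch (editor's illustration; not part of the original
+// header). Aint8, Bint8, col_offsets, Cuint8, Cint32_buffer and the
+// quantization parameters are hypothetical placeholders.
+//
+//   PackBMatrix<int8_t> packedB(
+//       matrix_op_t::NoTranspose, k, n, Bint8, n /*ld*/);
+//   PackAWithRowOffset<uint8_t> packedA(
+//       matrix_op_t::NoTranspose, m, k, Aint8, k /*ld*/);
+//   DoNothing<> doNothing;
+//   ReQuantizeOutput<false /*FUSE_RELU*/> reQuant(
+//       doNothing, C_multiplier, C_zero_point, Aq_zero_point, Bq_zero_point,
+//       packedA.getRowOffsetBuffer(), col_offsets, nullptr /*bias*/);
+//   fbgemmPacked(packedA, packedB, Cuint8, Cint32_buffer, n /*ldc*/, reQuant,
+//                0 /*thread_id*/, 1 /*num_threads*/);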
+
+/**
+ * @brief Perform depthwise separable convolution
+ */
+
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename outT,
+ typename processOutputType>
+void convDepthwiseSeparable(
+ const conv_param_t& conv_param_dw,
+ const conv_param_t& conv_param_1x1,
+ packingAMatrix& packdw,
+ packingBMatrix& packed_1x1,
+ outT* out,
+ const processOutputType& output);
+
+} // namespace fbgemm2
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
new file mode 100644
index 0000000..55718d4
--- /dev/null
+++ b/include/fbgemm/FbgemmFP16.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+// WARNING: this is a legacy fp16 fbgemm implementation and will soon be
+// upgraded to match the new fbgemm interface.
+
+#include <cassert>
+#include <memory>
+#include <vector>
+
+#include "Types.h"
+#include "Utils.h"
+
+namespace fbgemm2 {
+
+/// Class that packs a matrix in row-major format into an
+/// internal packed block-row-major format.
+class PackedGemmMatrixFP16 {
+public:
+  // Takes the smat input matrix in row-major format,
+  // packs it into a gemm-friendly blocked format,
+  // allocates space and sets up all the internal variables,
+  // and also premultiplies by alpha during packing.
+  // brow_ contains the tile size along the k dimension
+  // and is also the number of fma updates into the int16 container
+  // before flushing into fp32;
+  // the smaller brow_ is, the higher the flushing overhead.
+ PackedGemmMatrixFP16(const matrix_op_t trans, const int nrow,
+ const int ncol, const float alpha,
+ const float *smat,
+ const int brow = 512)
+ : nrow_(nrow), ncol_(ncol), brow_(brow) {
+
+ bcol_ = 8 * 1; // hardwired
+
+ // set up internal packing parameters
+ nbrow_ = ((numRows() % blockRowSize()) == 0)
+ ? (numRows() / blockRowSize())
+ : ((numRows() + blockRowSize()) / blockRowSize());
+ last_brow_ = ((nrow % blockRowSize()) == 0) ? blockRowSize()
+ : (nrow % blockRowSize());
+ nbcol_ = ((numCols() % blockColSize()) == 0)
+ ? (numCols() / blockColSize())
+ : ((numCols() + blockColSize()) / blockColSize());
+
+ if (numCols() != blockColSize() * nbcol_) {
+#ifdef VLOG
+ VLOG(0)
+ << "Packer warning: ncol(" << numCols()
+ << ") is not a multiple of internal block size (" << blockColSize()
+ << ")";
+ VLOG(0)
+          << "leftover is currently done via MKL: hence overhead will be incurred";
+#endif
+ }
+
+ // allocate and initialize packed memory
+ const int padding = 1024; // required by sw pipelined kernels
+ size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
+ pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) + padding);
+ for (auto i = 0; i < matSize(); i++) {
+ pmat_[i] = tconv(0.f, pmat_[i]);
+ }
+
+ // copy source matrix into packed matrix
+ this->packFromSrc(trans, alpha, smat);
+ }
+
+ ~PackedGemmMatrixFP16() {
+ free(pmat_);
+ }
+
+// protected:
+ // blocked row-major format address arithmetic
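+  // Worked example (editor's note): with brow_ = 512, bcol_ = 8, nrow_ = 1024
+  // and ncol_ = 16 (so nbrow_ = 2, nbcol_ = 2, last_brow_ = 512), element
+  // (600, 10) falls in block row 1 and block column 1, giving
+  // brow_offset = 8192, bcol_offset = 4096, inblock_offset = 88 * 8 + 2,
+  // i.e. index 12994 into pmat_.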
+ uint64_t addr(const int r_, const int c_) const {
+ uint64_t r = (uint64_t)r_;
+ uint64_t c = (uint64_t)c_;
+
+ uint64_t block_row_id = r / blockRowSize(),
+ brow_offset =
+ (block_row_id * nbcol_) * (blockRowSize() * blockColSize());
+ uint64_t block_col_id = c / blockColSize(),
+ bcol_offset =
+ block_col_id * ((block_row_id != nbrow_ - 1)
+ ? (blockRowSize() * blockColSize())
+ : (last_brow_ * blockColSize()));
+ uint64_t block_offset = brow_offset + bcol_offset;
+ uint64_t inblock_offset =
+ r % blockRowSize() * blockColSize() + c % blockColSize();
+
+ uint64_t index = block_offset + inblock_offset;
+ assert(index < matSize());
+ return index;
+ }
+
+ void packFromSrc(const matrix_op_t trans, const float alpha,
+ const float *smat) {
+ bool tr = (trans == matrix_op_t::Transpose);
+ // pack
+ for (int i = 0; i < numRows(); i++) {
+ for (int j = 0; j < numCols(); j++) {
+ pmat_[addr(i, j)] = tconv(
+ alpha * (
+ (tr == false)
+ ? smat[i * numCols() + j] : smat[i + numRows() * j]),
+ pmat_[addr(i, j)]);
+ }
+ }
+ }
+
+ const float16 &operator()(const int r, const int c) const {
+ uint64_t a = addr(r, c);
+ assert(r < numRows());
+ assert(c < numCols());
+ assert(a < this->matSize());
+ return pmat_[a];
+ }
+
+ int matSize() const { return size_; }
+ int numRows() const { return nrow_; }
+ int numCols() const { return ncol_; }
+ inline int blockRowSize() const { return brow_; }
+ inline int blockColSize() const { return bcol_; }
+
+ int nrow_, ncol_;
+ int brow_, last_brow_, bcol_;
+ int nbrow_, nbcol_;
+ uint64_t size_;
+ float16 *pmat_;
+
+ friend void cblas_gemm_compute(const matrix_op_t transa, const int m,
+ const float *A,
+ const PackedGemmMatrixFP16 &Bp,
+ const float beta, float *C);
+};
+
+/**
+ * restrictions: transa == CblasNoTrans
+ */
+extern void cblas_gemm_compute(const matrix_op_t transa, const int m,
+ const float *A,
+ const PackedGemmMatrixFP16 &Bp,
+ const float beta, float *C);
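+
+// A minimal usage sketch (editor's illustration; m, k, n and the buffers are
+// hypothetical placeholders):
+//
+//   // Pack the k x n row-major weight matrix B once, premultiplied by 1.0f.
+//   PackedGemmMatrixFP16 Bp(matrix_op_t::NoTranspose, k, n, 1.0f, Bfp32);
+//   // C (m x n) = A (m x k) * Bp, with beta = 0.
+//   cblas_gemm_compute(matrix_op_t::NoTranspose, m, Afp32, Bp, 0.0f, Cfp32);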
+
+} // namespace fbgemm2
diff --git a/include/fbgemm/FbgemmI8Spmdm.h b/include/fbgemm/FbgemmI8Spmdm.h
new file mode 100644
index 0000000..264b70e
--- /dev/null
+++ b/include/fbgemm/FbgemmI8Spmdm.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+#include <vector>
+#include "Utils.h"
+
+// #define FBGEMM_MEASURE_TIME_BREAKDOWN
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+#include <chrono>
+#include <iostream>
+extern double spmdm_initial_time;
+extern double spmdm_transpose_uint8_time;
+extern double spmdm_transpose_32xN_time;
+extern double spmdm_compute_time;
+extern double spmdm_transpose_Nx32_time;
+extern double spmdm_run_time;
+#endif
+
+namespace fbgemm2 {
+
+/**
+ * @brief A class to represent a matrix in Compressed Sparse Column (CSC)
+ * format.
+ *
+ * The second input matrix of a matrix multiplication is usually the weight
+ * matrix and can be sparse; it is usually more efficient to represent it in
+ * CSC format.
+ */
+class CompressedSparseColumn {
+ public:
+ CompressedSparseColumn(int num_of_rows, int num_of_cols);
+
+ std::vector<std::int32_t>& ColPtr() {
+ return colptr_;
+ }
+ std::vector<std::int16_t>& RowIdx() {
+ return rowidx_;
+ }
+ std::vector<std::int8_t>& Values() {
+ return values_;
+ }
+
+ std::size_t NumOfRows() const {
+ return num_rows_;
+ }
+ std::size_t NumOfCols() const {
+ return colptr_.size() - 1;
+ }
+ std::int32_t NumOfNonZeros() const {
+ return colptr_.back();
+ }
+
+ /**
+   * @return The number of non-zero elements as a fraction of the total number
+   *         of elements.
+ */
+ double Density() const;
+
+ /**
+ * @return True if the number of non-zeros per row is smaller than a small
+ * threshold.
+ */
+ bool IsHyperSparse() const;
+
+ /**
+ * @brief Perform dense-matrix * sparse matrix.
+ *
+ * C += A (dense matrix) * B (this CSC matrix) if accumulation = true \n
+ * C = A (dense matrix) * B (this CSC matrix) if accumulation = false
+ */
+ void SpMDM(
+ const block_type_t& block,
+ const std::uint8_t* A,
+ int lda,
+ bool accumulation,
+ std::int32_t* C,
+ int ldc) const;
+
+ private:
+ const std::size_t num_rows_;
+ std::vector<std::int32_t> colptr_;
+ std::vector<std::int16_t> rowidx_;
+ std::vector<std::int8_t> values_;
+
+ // Cache IsHyperSparse to minimize its overhead.
+ mutable bool hyper_sparse_;
+
+ // Whether we can reuse the cached hyper_sparse_ is determined by checking
+  // if NumOfNonZeros() is the same as old_nnz_ saved in the previous
+  // invocation of IsHyperSparse.
+ mutable std::int32_t old_nnz_;
+};
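+
+// CSC layout sketch (editor's illustration). For the 2 x 3 matrix
+//   [ 5 0 0 ]
+//   [ 0 0 7 ]
+// the buffers would hold colptr_ = {0, 1, 1, 2} (NumOfCols() + 1 entries),
+// rowidx_ = {0, 1} and values_ = {5, 7}, so NumOfNonZeros() == 2.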
+
+} // namespace fbgemm2
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
new file mode 100644
index 0000000..13f614a
--- /dev/null
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -0,0 +1,356 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+template <typename outT, typename inT, typename nextOPType>
+template<inst_set_t instSet>
+inline int memCopy<outT, inT, nextOPType>::f(outT* out, inT* inp,
+ const block_type_t& block, int ld_out, int ld_in) const {
+ static_assert(
+ std::is_same<outT, inT>::value,
+ "input and output data type must be of same type");
+ // only copy if destination is not the same as source
+ if (out + block.row_start*ld_out + block.col_start != inp) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ memcpy(out + block.col_start + i * ld_out,
+ inp + (i - block.row_start) * ld_in,
+ block.col_size*sizeof(inT));
+ }
+ }
+ return nextop_.template f<instSet>(out, out, block, ld_out, ld_out);
+}
+
+template <typename outT, typename inT, typename nextOPType>
+template<inst_set_t instSet>
+inline int DoSpmdmOnInpBuffer<outT, inT, nextOPType>::f(outT* out, inT* inp,
+ const block_type_t& block, int ld_out, int ld_in) const {
+ B_csc_.SpMDM(block, A_, lda_, true, inp, ld_in);
+ return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
+}
+
+template <bool FUSE_RELU, typename outT, typename inT, typename nextOPType>
+template <inst_set_t instSet>
+inline int ReQuantizeOutput<FUSE_RELU, outT, inT, nextOPType>::f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const {
+ static_assert(
+ std::is_same<inT, int32_t>::value,
+ "input data type must be of int32_t type");
+ if (instSet == inst_set_t::anyarch) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ inT raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
+ raw -= Aq_zero_point_ * q_col_offsets_[j];
+ if (q_row_offsets_) {
+ raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_;
+ }
+ if (bias_) {
+ raw += bias_[j];
+ }
+
+ float ab = raw * C_multiplier_;
+ long rounded = std::lrintf(ab) + C_zero_point_;
+
+ out[i * ld_out + j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point_) : 0l,
+ std::min(255l, rounded));
+ }
+ }
+ } else if (instSet == inst_set_t::avx2) {
+ if (std::is_same<outT, uint8_t>::value) {
+      // Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c
+      // using AVX2 instructions
+ __m256 multiplier_v = _mm256_set1_ps(C_multiplier_);
+
+ __m256i min_v = _mm256_set1_epi8(std::numeric_limits<uint8_t>::min());
+ __m256i max_v = _mm256_set1_epi8(std::numeric_limits<uint8_t>::max());
+
+ __m256i A_zero_point_v = _mm256_set1_epi32(Aq_zero_point_);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point_);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point_);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+ constexpr int VLEN = 8;
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ std::int32_t row_offset = q_row_offsets_
+ ? q_row_offsets_[i - block.row_start] * Bq_zero_point_
+ : 0;
+ __m256i row_offset_v = _mm256_set1_epi32(row_offset);
+ int j = block.col_start;
+ for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
+ j += (VLEN * 4)) {
+ __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start)));
+ __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+ 1 * VLEN));
+ __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+ 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+ 3 * VLEN));
+
+ //if (A_zero_pt != 0) {
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(q_col_offsets_ + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(q_col_offsets_ + j + VLEN)));
+ y_v = _mm256_sub_epi32(y_v, col_off_v);
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ q_col_offsets_ + j + 2 * VLEN)));
+ z_v = _mm256_sub_epi32(z_v, col_off_v);
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ q_col_offsets_ + j + 3 * VLEN)));
+ w_v = _mm256_sub_epi32(w_v, col_off_v);
+ //}
+
+ // if (row_offset != 0) {
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ y_v = _mm256_sub_epi32(y_v, row_offset_v);
+ z_v = _mm256_sub_epi32(z_v, row_offset_v);
+ w_v = _mm256_sub_epi32(w_v, row_offset_v);
+ //}
+ if (bias_) {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias_ + j)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias_ + j + VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias_ + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias_ + j + 3 * VLEN)));
+ }
+
+ /*
+ * Convert int32_t input to FP32 and multiply by FP32 scale.
+ * Both operations involve statistically unbiased roundings (with
+ * default MXCSR rounding mode):
+ * - Large int32_t values can't be exactly represented as FP32.
+ * CVTDQ2PS instruction on x86 would round it according to nearest
+ * FP32 value with ties to even (assuming default MXCSR rounding
+ * mode).
+           * - The product of two FP32 values is generally not exactly
+           *   representable as an FP32 value, and will be rounded to the
+           *   nearest FP32 value with ties to even with the default MXCSR
+           *   rounding mode.
+ */
+ __m256 x_scaled_v =
+ _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ __m256 y_scaled_v =
+ _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ __m256 z_scaled_v =
+ _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ __m256 w_scaled_v =
+ _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+
+ /*
+ * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
+ * CVTPS2DQ instruction rounds result according to nearest FP32 value
+ * with ties to even (assuming default MXCSR rounding mode). However,
+ * when conversion overflows, it produces INT32_MIN as a result. For
+ * large positive inputs the result of conversion can become negative,
+ * which affects the final requantization result. Note that on x86
+ * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This
+ * happens because float(INT32_MAX) rounds to 2**31, which overflows
+ * int32_t when it is converted back to integer.
+ *
+ * Thankfully, we can prove that overflow never happens in this
+ * requantization scheme. The largest positive input is INT32_MAX
+ * (2**31 - 1), which turns into 2**31 when converted to float. The
+ * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
+ * result is 2147483520 (compare to INT32_MAX = 2147483647), which
+ * fits into int32_t without overflow.
+ */
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ /*
+ * Standard final sequence on x86 AVX2:
+ * - Pack to int16_t and saturate
+ * - Add zero point
+ * - Pack to uint8_t and saturate
+ * - Clamp between qmin and qmax
+ */
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v),
+ C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v),
+ C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ /*
+ * xyzw_clamped_v has results in the following layout so we need to
+ * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
+ */
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ /*
+ * 4x CVTDQ2PS
+ * 4x MULPS
+ * 4x CVTPS2DQ
+ * 2x PACKSSDW
+ * 1x PACKUSWB
+ * 2x PADDW
+ * 1x PMAXUB
+ * 1x PMINUB
+ * 1x PERMD
+ * ---------------------
+ * 20 instructions total
+ */
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for (; j < block.col_start + (block.col_size / VLEN * VLEN);
+ j += VLEN) {
+ __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start)));
+
+ //if (A_zero_pt != 0) {
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(q_col_offsets_ + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ //}
+
+ // if (row_offset != 0) {
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ //}
+ if (bias_) {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias_ + j)));
+ }
+
+ __m256 x_scaled_v =
+ _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ /*
+ * x_clamped_v has results in the following layout so we need to
+ * permute: x0-3 garbage0-11 x4-7 garbage12-23
+ */
+ x_clamped_v =
+ _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ /*
+ * 1x CVTDQ2PS
+ * 1x MULPS
+ * 1x CVTPS2DQ
+ * 1x PACKSSDW
+ * 1x PACKUSWB
+ * 1x PADDW
+ * 1x PMAXUB
+ * 1x PMINUB
+ * 1x PERMD
+ * ---------------------
+ * 9 instructions total
+ */
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(out + i * ld_out + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for (; j < block.col_start + block.col_size; ++j) {
+ int32_t raw =
+ inp[(i - block.row_start) * ld_in + (j - block.col_start)];
+ // if (A_zero_pt != 0) {
+ raw -= Aq_zero_point_ * q_col_offsets_[j];
+ //}
+ raw -= row_offset;
+ if (bias_) {
+ raw += bias_[j];
+ }
+
+ float ab = raw * C_multiplier_;
+ long rounded = std::lrintf(ab) + C_zero_point_;
+
+ out[i * ld_out + j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point_) : 0l,
+ std::min(255l, rounded));
+ } // j loop remainder
+ } // i loop
+ } else {
+ assert(0 && "Not supported yet");
+ }
+ } else {
+ assert(0 && "Not supported yet");
+ }
+ return nextop_.template f<instSet>(out, out, block, ld_out, ld_out);
+}
+
+template <bool FUSE_RELU, typename outT, typename inT, typename nextOPType>
+template <inst_set_t instSet>
+inline int ReQuantizeForFloat<FUSE_RELU, outT, inT, nextOPType>::f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const {
+ static_assert(
+ std::is_same<int32_t, inT>::value,
+      "input data type is not of the expected type");
+  static_assert(
+      std::is_same<float, outT>::value,
+      "output data type is not of the expected type");
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start];
+ raw -= Aq_zero_point_ * q_col_offsets_[j];
+ raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_;
+ float res = raw * Aq_scale_ * Bq_scale_;
+ if (bias_) {
+ res += bias_[j];
+ }
+ out[i * ld_out + j] = res;
+ if (FUSE_RELU) {
+ out[i * ld_out + j] = std::max<outT>(0.0f, out[i * ld_out + j]);
+ }
+ }
+ }
+
+ return nextop_.template f<instSet>(out, out, block, ld_out, ld_out);
+}
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
new file mode 100644
index 0000000..35cd50a
--- /dev/null
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit
+ * integers.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx2.
+ */
+template <typename T>
+struct PackingTraits<
+ T,
+ std::int32_t,
+ inst_set_t::avx2,
+ typename std::enable_if<is_8bit<T>::value>::type> {
+ static constexpr int MR{12}; ///< Register block for M dimension
+ static constexpr int NR{8}; ///< Register block for N dimension
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 120}; ///< cache block for M dimension (multiple of MR)
+ static constexpr int NCB{8}; ///< cache block for N dimension (multiple of NR)
+ static constexpr int KCB{512}; ///< cache block for K dimension
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 16-bit
+ * integers.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx2.
+ */
+template <typename T>
+struct PackingTraits<
+ T,
+ std::int16_t,
+ inst_set_t::avx2,
+ typename std::enable_if<is_8bit<T>::value>::type> {
+ static constexpr int MR{3}; ///< Register block for M dimension
+ static constexpr int NR{16}; ///< Register block for N dimension; Total
+ ///< register used for N dimension: NCB/NR
+
+ static constexpr int ROW_INTERLEAVE{
+ 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 60}; ///< cache block for M dimension (multiple of MR)
+ static constexpr int NCB{
+ 64}; ///< cache block for N dimension (multiple of NR)
+ static constexpr int KCB{256}; ///< cache block for K dimension
+};
+
+/**
+ * @brief Packing parameter specialization for float input and float
+ * accumulation.
+ *
+ * This is picked when the template parameter T is of float type and instruction
+ * set is avx2.
+ */
+template <>
+struct PackingTraits<float, float, inst_set_t::avx2> {
+ static constexpr int MR{3}; ///< Register block for M dimension
+ static constexpr int NR{32}; ///< Register block for N dimension
+
+ static constexpr int ROW_INTERLEAVE{1}; ///< No Row interleave.
+
+ static constexpr int MCB{
+ 24}; ///< cache block for M dimension (multiple of MR)
+ static constexpr int NCB{
+ 64}; ///< cache block for N dimension (multiple of NR)
+ static constexpr int KCB{256}; ///< cache block for K dimension
+};
+
+/**
+ * @brief Packing parameter specialization for fp16 input and float
+ * accumulation.
+ *
+ * This is picked when template parameter T is of float16 type and instruction
+ * set is avx2.
+ */
+template <>
+struct PackingTraits<float16, float, inst_set_t::avx2> {
+ static constexpr int BCOL{8};
+ static constexpr int ROW_INTERLEAVE{1};
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit
+ * integers.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512.
+ */
+template <typename T>
+struct PackingTraits<
+ T,
+ std::int32_t,
+ inst_set_t::avx512,
+ typename std::enable_if<is_8bit<T>::value>::type> {
+ static constexpr int MR{28}; ///< Register block for M dimension
+ static constexpr int NR{16}; ///< Register block for N dimension
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix
+
+ static constexpr int MCB{
+ 140}; ///< cache block for M dimension (multiple of MR)
+ static constexpr int NCB{
+ 16}; ///< cache block for N dimension (multiple of NR)
+ static constexpr int KCB{512}; ///< cache block for K dimension
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 16-bit
+ * integers.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512.
+ */
+template <typename T>
+struct PackingTraits<
+ T,
+ std::int16_t,
+ inst_set_t::avx512,
+ typename std::enable_if<is_8bit<T>::value>::type> {
+ static constexpr int MR{6}; ///< Register block for M dimension
+ static constexpr int NR{32}; ///< Register block for N dimension; Total
+ ///< register used for N dimension: NCB/NR
+
+ static constexpr int ROW_INTERLEAVE{
+ 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 60}; ///< cache block for M dimension (multiple of MR)
+ static constexpr int NCB{
+ 128}; ///< cache block for N dimension (multiple of NR)
+ static constexpr int KCB{256}; ///< cache block for K dimension
+};
diff --git a/include/fbgemm/Types.h b/include/fbgemm/Types.h
new file mode 100644
index 0000000..c5c62dd
--- /dev/null
+++ b/include/fbgemm/Types.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+namespace fbgemm2 {
+
+typedef struct __attribute__((aligned(2))) __f16 {
+ uint16_t x;
+} float16;
+
+static inline float16 cpu_float2half_rn(float f) {
+ float16 ret;
+
+ static_assert(sizeof(unsigned int) == sizeof(float),
+ "Programming error sizeof(unsigned int) != sizeof(float)");
+
+ unsigned *xp = reinterpret_cast<unsigned int *>(&f);
+ unsigned x = *xp;
+ unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
+ unsigned sign, exponent, mantissa;
+
+ // Get rid of +NaN/-NaN case first.
+ if (u > 0x7f800000) {
+ ret.x = 0x7fffU;
+ return ret;
+ }
+
+ sign = ((x >> 16) & 0x8000);
+
+ // Get rid of +Inf/-Inf, +0/-0.
+ if (u > 0x477fefff) {
+ ret.x = sign | 0x7c00U;
+ return ret;
+ }
+ if (u < 0x33000001) {
+ ret.x = (sign | 0x0000);
+ return ret;
+ }
+
+ exponent = ((u >> 23) & 0xff);
+ mantissa = (u & 0x7fffff);
+
+ if (exponent > 0x70) {
+ shift = 13;
+ exponent -= 0x70;
+ } else {
+ shift = 0x7e - exponent;
+ exponent = 0;
+ mantissa |= 0x800000;
+ }
+ lsb = (1 << shift);
+ lsb_s1 = (lsb >> 1);
+ lsb_m1 = (lsb - 1);
+
+ // Round to nearest even.
+ remainder = (mantissa & lsb_m1);
+ mantissa >>= shift;
+ if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+ ++mantissa;
+ if (!(mantissa & 0x3ff)) {
+ ++exponent;
+ mantissa = 0;
+ }
+ }
+
+ ret.x = (sign | (exponent << 10) | mantissa);
+
+ return ret;
+}
+
+static inline float cpu_half2float(float16 h) {
+ unsigned sign = ((h.x >> 15) & 1);
+ unsigned exponent = ((h.x >> 10) & 0x1f);
+ unsigned mantissa = ((h.x & 0x3ff) << 13);
+
+ if (exponent == 0x1f) { /* NaN or Inf */
+ mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+ exponent = 0xff;
+ } else if (!exponent) { /* Denorm or Zero */
+ if (mantissa) {
+ unsigned int msb;
+ exponent = 0x71;
+ do {
+ msb = (mantissa & 0x400000);
+ mantissa <<= 1; /* normalize */
+ --exponent;
+ } while (!msb);
+ mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+ }
+ } else {
+ exponent += 0x70;
+ }
+
+ unsigned i = ((sign << 31) | (exponent << 23) | mantissa);
+ float ret;
+ memcpy(&ret, &i, sizeof(i));
+ return ret;
+}
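+
+// Round-trip sketch (editor's note): converting 0.1f with cpu_float2half_rn
+// and back with cpu_half2float yields approximately 0.0999756f, since fp16
+// carries only 10 mantissa bits.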
+
+static inline int8_t tconv(float x, int8_t /*rtype*/) { return int8_t(x); }
+static inline uint8_t tconv(float x, uint8_t /*rtype*/) { return uint8_t(x); }
+static inline float16 tconv(float x, float16 /*rtype*/) {
+ return cpu_float2half_rn(x);
+}
+
+template <typename T> T tconv(T x, T /*rtype*/) { return x; }
+}
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
new file mode 100644
index 0000000..22e5a16
--- /dev/null
+++ b/include/fbgemm/Utils.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <string>
+#include <type_traits>
+
+namespace fbgemm2 {
+
+/**
+ * @brief Helper struct to type specialize for uint8 and int8 together.
+ */
+template <typename T>
+struct is_8bit {
+ static constexpr bool value =
+ std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+};
+
+/**
+ * @brief Typed enum to specify matrix operations.
+ */
+enum class matrix_op_t { NoTranspose, Transpose };
+
+/**
+ * @brief Typed enum for supported instruction sets.
+ */
+enum class inst_set_t { anyarch, avx2, avx512 };
+
+/**
+ * @brief Typed enum for implementation type.
+ *
+ * ref is reference and opt is optimized.
+ */
+enum class impl_type_t { ref, opt };
+
+
+/**
+ * @brief A struct to represent a block of a matrix.
+ */
+struct block_type_t {
+ int row_start;
+ int row_size;
+ int col_start;
+ int col_size;
+
+ std::string toString() const {
+ std::string out = "";
+ out += "row start:" + std::to_string(row_start) + ", ";
+ out += "row size:" + std::to_string(row_size) + ", ";
+ out += "col start:" + std::to_string(col_start) + ", ";
+ out += "col size:" + std::to_string(col_size);
+ return out;
+ }
+};
+
+/**
+ * @brief A function to compare data in two buffers for closeness/equality.
+ */
+template <typename T>
+int compare_buffers(
+ const T* ref,
+ const T* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol = 1e-3);
+
+/**
+ * @brief Debugging helper.
+ */
+template <typename T>
+void printMatrix(
+ matrix_op_t trans,
+ const T* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+
+/**
+ * @brief Top-level routine to transpose a matrix.
+ *
+ * This calls transpose_8x8 or transpose_16x16 internally.
+ */
+void transpose_simd(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst);
+
+/**
+ * @brief Transpose a matrix using Intel AVX2.
+ *
+ * This is called if the code is running on a CPU with Intel AVX2 support.
+ */
+void transpose_8x8(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst);
+
+/**
+ * @brief Transpose a matrix using Intel AVX512.
+ *
+ * This is called if the code is running on a CPU with Intel AVX512 support.
+ */
+void transpose_16x16(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst);
+
+} // namespace fbgemm2
diff --git a/src/ExecuteKernel.cc b/src/ExecuteKernel.cc
new file mode 100644
index 0000000..0e3d122
--- /dev/null
+++ b/src/ExecuteKernel.cc
@@ -0,0 +1,12 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "ExecuteKernel.h"
+#include <immintrin.h>
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/Utils.h"
+
+namespace fbgemm2 {} // namespace fbgemm2
diff --git a/src/ExecuteKernel.h b/src/ExecuteKernel.h
new file mode 100644
index 0000000..55a2581
--- /dev/null
+++ b/src/ExecuteKernel.h
@@ -0,0 +1,11 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+#include "fbgemm/Fbgemm.h"
+#include "ExecuteKernelGeneric.h"
+#include "ExecuteKernelU8S8.h"
diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h
new file mode 100644
index 0000000..e83e943
--- /dev/null
+++ b/src/ExecuteKernelGeneric.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+#include "fbgemm/Fbgemm.h"
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Execution engine for the macro-kernel and output processing.
+ * ExecuteKernel is a derived class of CodeGenBase.
+ */
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename cT,
+ typename processOutputType>
+class ExecuteKernel : public CodeGenBase<
+ typename packingAMatrix::inpType,
+ typename packingBMatrix::inpType,
+ cT,
+ typename packingBMatrix::accType> {
+ public:
+ ExecuteKernel(
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>& packA,
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ typename packingBMatrix::accType* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess);
+ void execute(int kBlock);
+
+ private:
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>&
+ packedA_; ///< Packed block of matrix A.
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packedB_; ///< Packed matrix B.
+ int32_t kBlock_; ///< Block ID in the k dimension.
+ cT* matC_; ///< Output for matrix C.
+ typename packingAMatrix::accType*
+ C_buffer_; ///< the accumulation buffer for matrix C.
+ int32_t ldc_; ///< the leading dimension of matrix C.
+ const processOutputType& outputProcess_; ///< output processing function for
+ ///< the C tile in the macro-kernel.
+};
+
+} // namespace fbgemm2
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
new file mode 100644
index 0000000..5145869
--- /dev/null
+++ b/src/ExecuteKernelU8S8.cc
@@ -0,0 +1,354 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "ExecuteKernelU8S8.h"
+#include <cpuinfo.h>
+#include <chrono>
+
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double kernel_time = 0.0;
+double postprocessing_time = 0.0;
+#endif
+
+namespace fbgemm2 {
+
+template <typename packingAMatrix, typename cT, typename processOutputType>
+ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>::
+ ExecuteKernel(
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packA,
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ int32_t* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess)
+ : packedA_(packA),
+ packedB_(packB),
+ kBlock_(kBlock),
+ matC_(matC),
+ C_buffer_(C_buffer),
+ ldc_(ldc),
+ outputProcess_(outputProcess) {
+ if (cpuinfo_has_x86_avx512f()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::NCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::NCB;
+ } else {
+    assert(0 && "unsupported architecture");
+ }
+ C_tile_ = new int32_t[mbSize_ * nbSize_];
+}
+
+template <typename packingAMatrix, typename cT, typename processOutputType>
+void ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>::execute(int kBlock) {
+ // packedA_.printPackedMatrix("packedA from kernel");
+ // packedB_.printPackedMatrix("packedB from kernel");
+
+ int32_t bColBlocks = packedB_.blockCols();
+
+ int8_t* bBuf;
+ int8_t* bBuf_pf;
+
+ uint8_t* aBuf = packedA_.getBuf(0);
+
+ int32_t packed_rows_A = packedA_.numPackedRows();
+ int32_t row_start_A = packedA_.packedRowStart();
+
+ bool lastKBlock = packedB_.isThisLastKBlock(kBlock);
+ bool accum = kBlock > 0;
+
+ typename BaseType::jit_micro_kernel_fp fn;
+
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else if (cpuinfo_has_x86_avx2()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx2>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ for (int jb = 0; jb < bColBlocks; ++jb) {
+
+ bBuf = packedB_.getBuf(jb, kBlock);
+ // prefetch addr of the next packed block of B matrix
+ bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);
+
+    // Reuse the first rowblock of C_buffer_ unless C_buffer_ is the same as
+    // matC_ (in-place output processing).
+ int32_t* C_buffer_row_start = C_buffer_ +
+ ((C_buffer_ == reinterpret_cast<int32_t*>(matC_)) ? row_start_A * ldc_
+ : 0);
+ int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_;
+ int32_t leadingDim = ldc_;
+ if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) {
+ // In case we will access memory past C_buffer_, we use C_tile_ instead.
+ C_buffer_start = C_tile_;
+ leadingDim = nbSize_;
+ }
+
+ fn(aBuf,
+ bBuf,
+ bBuf_pf,
+ C_buffer_start,
+ packedA_.numPackedCols(),
+ leadingDim);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ kernel_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Output processing is done only once per rowblock
+ if (lastKBlock && jb == bColBlocks - 1) {
+      // When C_tile_ is used for the last column block, we need separate
+      // handling for that block.
+ int32_t nSize =
+ C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
+ if (nSize) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_buffer_row_start,
+ {row_start_A, packed_rows_A, 0, nSize},
+ ldc_,
+ ldc_);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_buffer_row_start,
+ {row_start_A, packed_rows_A, 0, nSize},
+ ldc_,
+ ldc_);
+ } else {
+ // TODO: Have default slower path
+          assert(0 && "unsupported architecture");
+ }
+ }
+
+ if (C_buffer_start == C_tile_) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_tile_,
+ {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ ldc_,
+ leadingDim);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_tile_,
+ {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ ldc_,
+ leadingDim);
+ } else {
+ // TODO: Have default slower path
+          assert(0 && "unsupported architecture");
+ }
+ }
+ } // output processing
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ postprocessing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ } // for each j block
+}
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ uint8_t,
+ ReQuantizeOutput<false /* FUSE_RELU*/>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ uint8_t,
+ ReQuantizeOutput<true>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<true>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<true>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<false>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<false>::outType,
+ int32_t,
+ ReQuantizeOutput<false>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<true>::outType,
+ int32_t,
+ ReQuantizeOutput<true>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ float,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeForFloat<false>::outType,
+ int32_t,
+ ReQuantizeForFloat<false>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<false>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<true>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ DoNothing<int32_t, int32_t>>;
+
+} // namespace fbgemm2
diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h
new file mode 100644
index 0000000..0bd7fc5
--- /dev/null
+++ b/src/ExecuteKernelU8S8.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include "ExecuteKernel.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Execution engine for uint8 x int8 matrix
+ * multiplication: runs the macro-kernel and output processing. ExecuteKernel is a
+ * derived class of CodeGenBase.
+ */
+template <typename packingAMatrix, typename cT, typename processOutputType>
+class ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>
+ : public CodeGenBase<
+ uint8_t,
+ int8_t,
+ int32_t,
+ typename packingAMatrix::accType> {
+ public:
+ using BaseType =
+ CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>;
+ /**
+ * @brief Constructor for initializing the parameters for macro-kernel and
+ * output processing type.
+ */
+ ExecuteKernel(
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packA,
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ int32_t* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess);
+ void execute(int kBlock);
+
+ ~ExecuteKernel() {
+ delete[] C_tile_;
+ }
+
+ private:
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packedA_; ///< Packed uint8 block of matrix A.
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>&
+ packedB_; ///< Packed int8 matrix B.
+ int32_t kBlock_; ///< Block ID in the k dimension.
+ cT* matC_; ///< Output for matrix C.
+ int32_t* C_buffer_; ///< the accumulation buffer for matrix C.
+ int32_t ldc_; ///< the leading dimension of matrix C.
+ const processOutputType& outputProcess_; ///< output processing function for
+ ///< matrix C in the macro-kernel.
+  int32_t* C_tile_; ///< buffer for the last N block when N is not an exact
+                    ///< multiple of NCB.
+ int mbSize_; ///< block size in the m dimension.
+ int nbSize_; ///< block size in the n dimension.
+};
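+
+// Note on usage: ExecuteKernel objects are normally created and driven by
+// fbgemmPacked() (see Fbgemm.cc), which packs one block of matrix A at a time
+// and then calls execute(kBlock) on it; direct use of this class is not
+// expected.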
+
+} // namespace fbgemm2
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
new file mode 100644
index 0000000..f3bac97
--- /dev/null
+++ b/src/Fbgemm.cc
@@ -0,0 +1,363 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/Fbgemm.h"
+#include <cpuinfo.h>
+#include <stdexcept>
+#include "ExecuteKernel.h"
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double packing_time = 0.0;
+double computing_time = 0.0;
+double run_time = 0.0;
+#endif
+
+using namespace fbgemm2;
+
+namespace fbgemm2 {
+
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename cT,
+ typename processOutputType>
+void fbgemmPacked(
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>& packA,
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packB,
+ cT* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const processOutputType& outProcess,
+ int thread_id,
+ int /* num_threads */) {
+ static_assert(
+ std::is_same<
+ typename packingAMatrix::accType,
+ typename packingBMatrix::accType>::value,
+ "Accumulation type of both matrices should be the same");
+
+ int MCB, KCB;
+
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::KCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::KCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+
+ int MDim = packA.numRows();
+ int KDim = packB.numRows();
+
+ int mBlocks = (MDim + MCB - 1) / MCB;
+ int kBlocks = (KDim + KCB - 1) / KCB;
+
+ // remainders
+ int _mc = MDim % MCB;
+ int _kc = KDim % KCB;
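+  // For example (illustrative values): with MDim = 100 and MCB = 48, mBlocks is
+  // 3 and the last m block covers the remaining _mc = 4 rows; when MDim is an
+  // exact multiple of MCB, _mc is 0 and every block uses the full MCB rows.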
+
+ int kc, mc;
+
+ block_type_t blockA{0, 0, 0, 0};
+
+ // B must be prepacked
+ assert(packB.isPrePacked() && "B matrix must be prepacked");
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
+ t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+ t_very_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
+ exeKernelObj(packA, packB, 0, C, C_buffer, ldc, outProcess);
+  // TODO: thread-based work division
+ for (int i = 0; i < mBlocks; ++i) {
+ mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc;
+ for (int k = 0; k < kBlocks; ++k) {
+ kc = (k != kBlocks - 1 || _kc == 0) ? KCB : _kc;
+ // pack A matrix
+ blockA = {i * MCB, mc, k * KCB, kc};
+
+ packA.pack(blockA);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ packing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ exeKernelObj.execute(k);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ computing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+ }
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt =
+ std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
+ .count();
+ run_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+}
+
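+// A minimal, illustrative call sequence for the 32-bit-accumulation,
+// requantized path (constructor arguments are elided; the buffer names Aint8,
+// Bint8, Cint8 and Cint32_buffer are hypothetical placeholders):
+//
+//   PackAWithRowOffset<uint8_t, int32_t> packA(/* ... Aint8 ... */);
+//   PackBMatrix<int8_t, int32_t> packB(/* ... Bint8 ... */);
+//   ReQuantizeOutput<false> outProcess(/* ... requantization parameters ... */);
+//   fbgemmPacked(packA, packB, Cint8, Cint32_buffer, ldc, outProcess,
+//                /*thread_id=*/0, /*num_threads=*/1);
+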
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+// 16 bit accumulation functions
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<false>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<true>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoNothing<int32_t, int32_t>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm2
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
new file mode 100644
index 0000000..7bbfa54
--- /dev/null
+++ b/src/FbgemmFP16.cc
@@ -0,0 +1,293 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmFP16.h"
+
+#include <cpuinfo.h>
+
+#include "FbgemmFP16UKernels.h"
+
+using namespace std;
+
+namespace fbgemm2 {
+
+/// Packs a tile of a matrix from row-major or col-major format into the
+/// internal packed blocked-row-major format.
+
+/// TODO: make it fast with an AVX2 transpose
+inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
+ // for (int r = 0; r < nrow; ++r) {
+ // for (int c = 0; c < ncol; ++c) {
+ // to[r + c * nrow] = from[r * ldim + c];
+ // }
+ // }
+  transpose_simd(nrow, ncol, from, ldim, to, nrow);
+}
+
+struct KernelInfo {
+ using knl_ptr = funcptr_fp16;
+ // optimized kernels to cover all cases
+ static constexpr array<knl_ptr, 15> kernel = {{
+ nullptr, gemmkernel_1x1_AVX2_fA0fB0fC0,
+ gemmkernel_2x1_AVX2_fA0fB0fC0, gemmkernel_3x1_AVX2_fA0fB0fC0,
+ gemmkernel_4x1_AVX2_fA0fB0fC0, gemmkernel_5x1_AVX2_fA0fB0fC0,
+ gemmkernel_6x1_AVX2_fA0fB0fC0, gemmkernel_7x1_AVX2_fA0fB0fC0,
+ gemmkernel_8x1_AVX2_fA0fB0fC0, gemmkernel_9x1_AVX2_fA0fB0fC0,
+ gemmkernel_10x1_AVX2_fA0fB0fC0, gemmkernel_11x1_AVX2_fA0fB0fC0,
+ gemmkernel_12x1_AVX2_fA0fB0fC0, gemmkernel_13x1_AVX2_fA0fB0fC0,
+ gemmkernel_14x1_AVX2_fA0fB0fC0
+ }};
+
+ // autotuned kernel splits for various cases m = 1:mb_max
+ // may need re-autotuning for new uarch
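+  // Each entry partition[m] holds up to two (kernel_nrows, nkernel_nrows) pairs
+  // whose products sum to m; e.g. partition[16] = {{10, 1}, {6, 1}} handles a
+  // 16-row block with one call to the 10-row kernel and one call to the 6-row
+  // kernel.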
+ static constexpr array<array<pair<int, int>, 2>, 121 > partition = {
+ {
+ {{ { 0, 0 }, { 0, 0 } } },
+ {{ { 1, 1 }, { 0, 0 } } },
+ {{ { 2, 1 }, { 0, 0 } } },
+ {{ { 3, 1 }, { 0, 0 } } },
+ {{ { 4, 1 }, { 0, 0 } } },
+ {{ { 5, 1 }, { 0, 0 } } },
+ {{ { 6, 1 }, { 0, 0 } } },
+ {{ { 7, 1 }, { 0, 0 } } },
+ {{ { 8, 1 }, { 0, 0 } } },
+ {{ { 9, 1 }, { 0, 0 } } },
+ {{ { 10, 1 }, { 0, 0 } } },
+ {{ { 11, 1 }, { 0, 0 } } },
+ {{ { 12, 1 }, { 0, 0 } } },
+ {{ { 13, 1 }, { 0, 0 } } },
+ {{ { 14, 1 }, { 0, 0 } } },
+ {{ { 8, 1 }, { 7, 1 } } },
+ {{ { 10, 1 }, { 6, 1 } } },
+ {{ { 11, 1 }, { 6, 1 } } },
+ {{ { 12, 1 }, { 6, 1 } } },
+ {{ { 11, 1 }, { 8, 1 } } },
+ {{ { 11, 1 }, { 9, 1 } } },
+ {{ { 12, 1 }, { 9, 1 } } },
+ {{ { 11, 2 }, { 0, 0 } } },
+ {{ { 12, 1 }, { 11, 1 } } },
+ {{ { 12, 2 }, { 0, 0 } } },
+ {{ { 13, 1 }, { 12, 1 } } },
+ {{ { 13, 2 }, { 0, 0 } } },
+ {{ { 14, 1 }, { 13, 1 } } },
+ {{ { 14, 2 }, { 0, 0 } } },
+ {{ { 11, 2 }, { 7, 1 } } },
+ {{ { 10, 3 }, { 0, 0 } } },
+ {{ { 12, 2 }, { 7, 1 } } },
+ {{ { 12, 2 }, { 8, 1 } } },
+ {{ { 11, 3 }, { 0, 0 } } },
+ {{ { 13, 2 }, { 8, 1 } } },
+ {{ { 13, 2 }, { 9, 1 } } },
+ {{ { 13, 2 }, { 10, 1 } } },
+ {{ { 13, 2 }, { 11, 1 } } },
+ {{ { 13, 2 }, { 12, 1 } } },
+ {{ { 13, 3 }, { 0, 0 } } },
+ {{ { 14, 2 }, { 12, 1 } } },
+ {{ { 14, 2 }, { 13, 1 } } },
+ {{ { 11, 3 }, { 9, 1 } } },
+ {{ { 11, 3 }, { 10, 1 } } },
+ {{ { 11, 4 }, { 0, 0 } } },
+ {{ { 12, 3 }, { 9, 1 } } },
+ {{ { 12, 3 }, { 10, 1 } } },
+ {{ { 13, 3 }, { 8, 1 } } },
+ {{ { 13, 3 }, { 9, 1 } } },
+ {{ { 13, 3 }, { 10, 1 } } },
+ {{ { 13, 3 }, { 11, 1 } } },
+ {{ { 13, 3 }, { 12, 1 } } },
+ {{ { 13, 4 }, { 0, 0 } } },
+ {{ { 14, 3 }, { 11, 1 } } },
+ {{ { 11, 4 }, { 10, 1 } } },
+ {{ { 12, 4 }, { 7, 1 } } },
+ {{ { 14, 4 }, { 0, 0 } } },
+ {{ { 12, 4 }, { 9, 1 } } },
+ {{ { 12, 4 }, { 10, 1 } } },
+ {{ { 12, 4 }, { 11, 1 } } },
+ {{ { 13, 4 }, { 8, 1 } } },
+ {{ { 13, 4 }, { 9, 1 } } },
+ {{ { 13, 4 }, { 10, 1 } } },
+ {{ { 13, 4 }, { 11, 1 } } },
+ {{ { 11, 5 }, { 9, 1 } } },
+ {{ { 13, 5 }, { 0, 0 } } },
+ {{ { 14, 4 }, { 10, 1 } } },
+ {{ { 12, 5 }, { 7, 1 } } },
+ {{ { 12, 5 }, { 8, 1 } } },
+ {{ { 14, 4 }, { 13, 1 } } },
+ {{ { 14, 5 }, { 0, 0 } } },
+ {{ { 12, 5 }, { 11, 1 } } },
+ {{ { 13, 5 }, { 7, 1 } } },
+ {{ { 11, 6 }, { 7, 1 } } },
+ {{ { 13, 5 }, { 9, 1 } } },
+ {{ { 13, 5 }, { 10, 1 } } },
+ {{ { 13, 5 }, { 11, 1 } } },
+ {{ { 13, 5 }, { 12, 1 } } },
+ {{ { 13, 6 }, { 0, 0 } } },
+ {{ { 12, 6 }, { 7, 1 } } },
+ {{ { 12, 6 }, { 8, 1 } } },
+ {{ { 12, 6 }, { 9, 1 } } },
+ {{ { 12, 6 }, { 10, 1 } } },
+ {{ { 12, 6 }, { 11, 1 } } },
+ {{ { 12, 7 }, { 0, 0 } } },
+ {{ { 13, 6 }, { 7, 1 } } },
+ {{ { 13, 6 }, { 8, 1 } } },
+ {{ { 13, 6 }, { 9, 1 } } },
+ {{ { 13, 6 }, { 10, 1 } } },
+ {{ { 13, 6 }, { 11, 1 } } },
+ {{ { 13, 6 }, { 12, 1 } } },
+ {{ { 13, 7 }, { 0, 0 } } },
+ {{ { 12, 7 }, { 8, 1 } } },
+ {{ { 12, 7 }, { 9, 1 } } },
+ {{ { 14, 6 }, { 10, 1 } } },
+ {{ { 12, 7 }, { 11, 1 } } },
+ {{ { 13, 7 }, { 5, 1 } } },
+ {{ { 13, 7 }, { 6, 1 } } },
+ {{ { 13, 7 }, { 7, 1 } } },
+ {{ { 13, 7 }, { 8, 1 } } },
+ {{ { 13, 7 }, { 9, 1 } } },
+ {{ { 13, 7 }, { 10, 1 } } },
+ {{ { 13, 7 }, { 11, 1 } } },
+ {{ { 13, 7 }, { 12, 1 } } },
+ {{ { 12, 8 }, { 8, 1 } } },
+ {{ { 12, 8 }, { 9, 1 } } },
+ {{ { 12, 8 }, { 10, 1 } } },
+ {{ { 12, 8 }, { 11, 1 } } },
+ {{ { 12, 9 }, { 0, 0 } } },
+ {{ { 11, 9 }, { 10, 1 } } },
+ {{ { 13, 8 }, { 6, 1 } } },
+ {{ { 13, 8 }, { 7, 1 } } },
+ {{ { 13, 8 }, { 8, 1 } } },
+ {{ { 13, 8 }, { 9, 1 } } },
+ {{ { 13, 8 }, { 10, 1 } } },
+ {{ { 13, 8 }, { 11, 1 } } },
+ {{ { 12, 9 }, { 8, 1 } } },
+ {{ { 13, 9 }, { 0, 0 } } },
+ {{ { 12, 9 }, { 10, 1 } } },
+ {{ { 12, 9 }, { 11, 1 } } },
+ {{ { 12, 10 }, { 0, 0 } } }
+ }
+ };
+};
+constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
+constexpr array<array<pair<int, int>, 2>, 121 > KernelInfo::partition;
+
+void
+cblas_gemm_compute(const matrix_op_t transa, const int m, const float *A,
+ const PackedGemmMatrixFP16 &Bp, const float beta,
+ float *C) {
+  // runtime preconditions: required CPU features and supported transpose mode
+ assert(cpuinfo_initialize());
+ assert(cpuinfo_has_x86_fma3());
+ assert(cpuinfo_has_x86_f16c());
+ assert(transa == matrix_op_t::NoTranspose);
+
+ // constants
+ const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
+ const int mb_max = 120;
+ constexpr int simd_width = 8;
+ constexpr int kernel_ncol_blocks = 1;
+ constexpr int kernel_ncols = kernel_ncol_blocks * simd_width;
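+  // With kernel_ncol_blocks = 1 and simd_width = 8, each micro-kernel call
+  // writes an 8-float-wide (one YMM register) strip of C per handled row.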
+
+ // private scratchpad storage
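+  // (256 * 1024 floats, i.e. 1 MiB per thread, holding the packed A tile)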
+ static thread_local unique_ptr<std::array<float, 256 * 1024> > scratchpad(
+ new std::array<float, 256 * 1024>());
+
+ GemmParams gp;
+ for (auto m0 = 0; m0 < m; m0 += mb_max) {
+ int mb = std::min(mb_max, m - m0);
+ assert(mb < KernelInfo::partition.size());
+ for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
+
+      // Set up accumulation to avoid NaN problems: C is read (accumulated
+      // into) only when beta != 0.0f or for k blocks after the first.
+ float beta_;
+ uint64_t accum;
+ if (k_ind == 0) {
+        // accumulate only if beta != 0.0;
+        // do not accumulate otherwise
+ beta_ = beta;
+ accum = (beta_ == 0.0f) ? 0 : 1;
+ } else {
+ // always accumulate with beta_ = 1.0f
+ beta_ = 1.0f;
+ accum = 1;
+ }
+
+ const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind);
+
+ auto m1 = 0;
+ for (auto c = 0; c < 2; c++) {
+
+ auto kernel_nrows = KernelInfo::partition[mb][c].first;
+ auto nkernel_nrows = KernelInfo::partition[mb][c].second;
+
+ auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
+ for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
+ assert(kernel_nrows * kb < scratchpad->size());
+ PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+
+ int nbcol = n / Bp.blockColSize();
+ gp.k = kb;
+ gp.A = scratchpad->data();
+ gp.B = &(Bp(k_ind, 0));
+ gp.beta = &beta_;
+ gp.accum = accum;
+ gp.C = &C[m2 * ldc];
+ gp.ldc = ldc * sizeof(C[0]);
+ gp.b_block_cols = nbcol;
+ gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
+ if ((n % Bp.blockColSize()) == 0) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ } else {
+ int last_blk_col = nbcol * Bp.blockColSize();
+ if (nbcol) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ }
+
+ // leftover
+ int rem = n - last_blk_col;
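+            // Illustrative example: with n = 20 and blockColSize() = 8, nbcol
+            // is 2, last_blk_col is 16, and the remaining rem = 4 columns are
+            // handled below through the small c_tmp buffer.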
+ assert(rem < kernel_ncols);
+ int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
+ : (rem / simd_width);
+ assert(b == 1);
+ if ((rem % simd_width) == 0) {
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = &C[m2 * ldc + last_blk_col];
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ } else {
+ // small temporary buffer
+ float c_tmp[16 * 24] = { 0 };
+ assert((16 * 24) > kernel_nrows * kernel_ncols);
+
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = c_tmp;
+ gp.ldc = 8 * sizeof(C[0]);
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ for (int i = 0; i < kernel_nrows; i++) {
+ // Todo: use assembly
+ for (int j = last_blk_col; j < n; j++) {
+ assert(i * 8 + (j - last_blk_col) <
+ sizeof(c_tmp) / sizeof(c_tmp[0]));
+ if (accum == 0) {
+ C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
+ } else {
+ C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
+ c_tmp[i * 8 + (j - last_blk_col)];
+ }
+ }
+ }
+ }
+ }
+ }
+ m1 += kernel_nrows * nkernel_nrows;
+ }
+ }
+ }
+}
+
+
+} // namespace fbgemm2
diff --git a/src/FbgemmFP16UKernels.cc b/src/FbgemmFP16UKernels.cc
new file mode 100644
index 0000000..ec1b297
--- /dev/null
+++ b/src/FbgemmFP16UKernels.cc
@@ -0,0 +1,2203 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmFP16UKernels.h"
+
+namespace fbgemm2 {
+
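+// Each kernel below reads its arguments from the GemmParams struct it is passed
+// through fixed byte offsets: k at +0, A at +8, B at +16, beta at +24, accum at
+// +32, C at +40, ldc at +48, b_block_cols at +56 and b_block_size at +64. The
+// GemmParams layout must stay in sync with these offsets.
+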
+void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm1,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm1\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm1,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm1\t\n"
+"add r11, 32\t\n"
+"add r9,8\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm2\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm2\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm2\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm2\t\n"
+"add r11, 32\t\n"
+"add r9,16\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm3\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm3\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm3\t\n"
+"add r9,24\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm4\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm4\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm4\t\n"
+"add r9,32\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm5\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm5\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm5\t\n"
+"add r9,40\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm6\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm6\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm6\t\n"
+"add r9,48\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm7\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm7\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm7\t\n"
+"add r9,56\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm8\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm8\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm8\t\n"
+"add r9,64\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm9\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm9\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm9\t\n"
+"add r9,72\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm10\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm10\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm10\t\n"
+"add r9,80\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm11\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm11\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm11\t\n"
+"add r9,88\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm12\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm12\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+88]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+92]\t\n"
+"vfmadd231ps ymm11,ymm14,ymm12\t\n"
+"add r9,96\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+"vxorps ymm12,ymm12,ymm12\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm12,ymm15,ymm13\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm13\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+88]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+92]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+96]\t\n"
+"vfmadd231ps ymm11,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+100]\t\n"
+"vfmadd231ps ymm12,ymm14,ymm13\t\n"
+"add r9,104\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+"vxorps ymm12,ymm12,ymm12\t\n"
+"vxorps ymm13,ymm13,ymm13\t\n"
+
+"mov r11, 0\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm12,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm13,ymm15,ymm14\t\n"
+"add r9,56\t\n"
+"add r11, 16\t\n"
+"inc r14\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+
+} // namespace fbgemm2
diff --git a/src/FbgemmFP16UKernels.h b/src/FbgemmFP16UKernels.h
new file mode 100644
index 0000000..bf7f247
--- /dev/null
+++ b/src/FbgemmFP16UKernels.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#ifndef FBGEMM_UKERNELS
+#define FBGEMM_UKERNELS
+#include <cstdint>
+#include <tuple>
+#include <vector>
+#include "fbgemm/Types.h"
+
+namespace fbgemm2 {
+
+using fp16 = float16;
+using fp32 = float;
+// Field order must match the [r14 + offset] loads in FbgemmFP16UKernels.cc
+// (k at +0, A at +8, B at +16, beta at +24, accum at +32, C at +40,
+//  ldc at +48, b_block_cols at +56, b_block_size at +64).
+struct GemmParams {
+  uint64_t k; float *A; const fp16 *B; float *beta; uint64_t accum;
+  float *C; uint64_t ldc; uint64_t b_block_cols; uint64_t b_block_size;
+};
+void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp);
+typedef void (*funcptr_fp16)(GemmParams *gp);
+
+} // namespace fbgemm2
+
+#endif
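
The declarations above cover row strips of 1 through 14, all sharing the GemmParams calling convention. As a readability aid only (not part of the commit), here is a minimal sketch of how a caller might index these kernels by strip height through funcptr_fp16; reading the "MR" in gemmkernel_MRx1 as the number of A rows handled per call is an assumption inferred from the per-iteration broadcast loads in the asm.

#include "FbgemmFP16UKernels.h"

namespace fbgemm2 {
// kernel_for_rows[mr] is the microkernel assumed to handle an mr-row strip of A
// against one packed column block of B (index 0 is unused).
static const funcptr_fp16 kernel_for_rows[15] = {
    nullptr,
    gemmkernel_1x1_AVX2_fA0fB0fC0,  gemmkernel_2x1_AVX2_fA0fB0fC0,
    gemmkernel_3x1_AVX2_fA0fB0fC0,  gemmkernel_4x1_AVX2_fA0fB0fC0,
    gemmkernel_5x1_AVX2_fA0fB0fC0,  gemmkernel_6x1_AVX2_fA0fB0fC0,
    gemmkernel_7x1_AVX2_fA0fB0fC0,  gemmkernel_8x1_AVX2_fA0fB0fC0,
    gemmkernel_9x1_AVX2_fA0fB0fC0,  gemmkernel_10x1_AVX2_fA0fB0fC0,
    gemmkernel_11x1_AVX2_fA0fB0fC0, gemmkernel_12x1_AVX2_fA0fB0fC0,
    gemmkernel_13x1_AVX2_fA0fB0fC0, gemmkernel_14x1_AVX2_fA0fB0fC0};
} // namespace fbgemm2

With a filled-in GemmParams gp, kernel_for_rows[mr](&gp) would then run the mr-row variant.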
diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc
new file mode 100644
index 0000000..54e2272
--- /dev/null
+++ b/src/FbgemmI8Depthwise.cc
@@ -0,0 +1,1953 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmI8Depthwise.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <tuple>
+#include <vector>
+
+#include <x86intrin.h>
+
+using namespace std;
+
+namespace fbgemm2
+{
+
+// masks[m] selects the first m 32-bit lanes for _mm256_maskload_epi32 when a
+// channel group has fewer than 32 remaining channels (m = remainder / 4).
+static array<array<int, 8>, 8> masks = {{
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+}};
+
+template <int KERNEL_PROD>
+PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
+ int K, const int8_t *smat)
+ : K_(K) {
+ // Transpose the input matrix to make packing faster.
+ vector<int8_t> smat_transposed(K * KERNEL_PROD);
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ for (int j = 0; j < K; ++j) {
+ smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
+ }
+ }
+
+ // Allocate packed arrays
+ constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
+ pmat_ = static_cast<int8_t *>(aligned_alloc(
+ 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+
+ // Pack input matrix
+ // The layout is optimized to use vpmaddubsw efficiently (see
+ // madd_epi16x4_packed function).
+ // For a group of 32 channels, we have 10 32B SIMD registers.
+ // Denote ith channel jth filter as (i, j)
+ // 0th SIMD register:
+ // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
+ // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
+ // 1st SIMD register:
+ // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
+ // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
+ // 2nd SIMD register:
+ // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
+ // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
+ // 3rd SIMD register:
+ // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
+ // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
+ // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
+ // coefficients
+ // ...
+ //
+ // REMAINDER
+ // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9
+ // 8th SIMD register:
+ // (0, 8), zero, ..., (7, 8), zero
+ // (16, 8), zero, ..., (23, 8), zero
+ // 9th SIMD register:
+ // (8, 8), zero, ..., (15, 8), zero
+ // (24, 8), zero, ..., (31, 8), zero
+ // We use madd_epi16_packed for this case
+ //
+ // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10
+ // 8th SIMD register:
+ // (0, 8), (0, 9), ..., (7, 8), (7, 9)
+ // (16, 8), (16, 9), ..., (23, 8), (23, 9)
+ // 9th SIMD register:
+ // (8, 8), (8, 9), ..., (15, 8), (15, 9)
+ // (24, 8), (24, 9), ..., (31, 8), (31, 9)
+ //
+ // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11
+ // 8th SIMD register:
+ // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
+ // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
+ // 9th SIMD register:
+ // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
+ // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
+ // 10th SIMD register:
+ // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
+ // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
+ // 11th SIMD register:
+ // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
+ // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
+ for (int k1 = 0; k1 < K; k1 += 32) {
+ array<__m256i, KERNEL_PROD> b_v;
+ int remainder = K - k1;
+ if (remainder < 32) {
+ __m256i mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ b_v[i] = _mm256_maskload_epi32(
+ reinterpret_cast<const int *>(smat_transposed.data() + i * K + k1),
+ mask_v);
+ }
+ } else {
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(
+ smat_transposed.data() + i * K + k1));
+ }
+ }
+
+ // Interleave 2 SIMD registers
+ array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi16;
+ __m256i zero_v = _mm256_setzero_si256();
+ for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
+ if (2 * i + 1 >= KERNEL_PROD) {
+ b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
+ } else {
+ b_interleaved_epi16[2 * i] =
+ _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ }
+ }
+
+ // Interleave 4 SIMD registers
+ array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi32;
+ for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
+ b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ }
+ for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
+ b_interleaved_epi32[i] = b_interleaved_epi16[i];
+ }
+
+ for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(
+ &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
+ b_interleaved_epi32[i]);
+ }
+ }
+}
+
+template <int KERNEL_PROD>
+PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix()
+{
+ free(pmat_);
+}
+
+template class PackedDepthWiseConvMatrix<3 * 3>;
+template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
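
As a usage illustration only (not part of the commit), the constructor above can be driven as sketched below; the channel-major smat layout, with KERNEL_PROD contiguous int8 weights per channel, is inferred from the transpose loop at the top of the constructor.

#include <cstdint>
#include <vector>
#include "FbgemmI8Depthwise.h"

// Pack a K-channel 3x3 depth-wise filter: weights[k * 9 + p] is filter tap p
// of channel k, which is what the constructor's transpose loop consumes.
void pack_3x3_example(int K, const std::vector<int8_t>& weights /* size K * 9 */) {
  fbgemm2::PackedDepthWiseConvMatrix<3 * 3> packed(K, weights.data());
  // The internal buffer now holds the vpmaddubsw-friendly interleaved layout
  // documented in the comment block above, 32 channels per group.
}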
+
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline))
+void madd_epi16x4_packed(
+ __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v,
+ const __m256i* b,
+ __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1 + a2 * b2
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline))
+void madd_epi16x3_packed(
+ __m256i a0_v, __m256i a1_v, __m256i a2_v,
+ const __m256i* b,
+ __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void
+madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
+ __m256i *c1_v, __m256i *c2_v, __m256i *c3_v,
+ __m256i *a_sum = nullptr) {
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// c = a0 * b0
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void
+madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
+ __m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// K is the number of accumulations we're doing
+template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
+static inline __attribute__((always_inline)) void
+inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
+ int remainder, __m256i *a_sum = nullptr) {
+ array<__m256i, 4> c, c_temp;
+ array<__m256i, 2> a_sum_temp{};
+
+ int k = 0;
+ if (K >= 4) {
+ madd_epi16x4_packed<SUM_A>(a_v[0], a_v[1], a_v[2], a_v[3], Bp,
+ &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
+
+ for (k = 4; k < K / 4 * 4; k += 4) {
+ madd_epi16x4_packed<SUM_A>(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3],
+ Bp + k, &c_temp[0], &c_temp[1], &c_temp[2],
+ &c_temp[3], a_sum_temp.data());
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+ } else {
+ c[0] = _mm256_setzero_si256();
+ c[1] = _mm256_setzero_si256();
+ c[2] = _mm256_setzero_si256();
+ c[3] = _mm256_setzero_si256();
+ }
+
+ if (K - k == 3) {
+ madd_epi16x3_packed<SUM_A>(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k,
+ &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3],
+ a_sum_temp.data());
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
+ c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
+ c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
+ c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
+
+ if (K - k == 0 || K - k == 3) {
+ c[0] = c_temp[0];
+ c[1] = c_temp[1];
+ c[2] = c_temp[2];
+ c[3] = c_temp[3];
+ } else {
+ if (K - k == 1) {
+ madd_epi16_packed<SUM_A>(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3],
+ a_sum_temp.data());
+ } else if (K - k == 2) {
+ madd_epi16x2_packed<SUM_A>(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1],
+ &c[2], &c[3], a_sum_temp.data());
+ }
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ if (REMAINDER) {
+ for (int r = 0; r < remainder / 8; ++r) {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + r * 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)),
+ c[r]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]);
+ }
+ }
+ } else {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C),
+ _mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)),
+ c[0]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 16),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 24),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]);
+ }
+ }
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
+ a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
+ a_sum[2] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
+ a_sum[3] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
+ }
+}
+
+template <bool SUM_A = false, bool REMAINDER = false>
+static inline __attribute__((always_inline))
+void inner_prod_3x3_packed_(const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder,
+ a_sum);
+}
+
+// Almost the same as ReQuantizeOutput in OutputProcessing-inl.h, but with a
+// different row_offsets value for each row because of depth-wise convolution
+template <bool FUSE_RELU, bool HAS_BIAS, bool PER_CHANNEL_QUANTIZATION>
+static inline __attribute__((always_inline)) void requantize_(
+ int32_t A_zero_point,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ __m256 multiplier_v = _mm256_setzero_ps();
+ if (!PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_set1_ps(*C_multiplier);
+ }
+
+ __m256i min_v = _mm256_set1_epi8(numeric_limits<uint8_t>::min());
+ __m256i max_v = _mm256_set1_epi8(numeric_limits<uint8_t>::max());
+
+ __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+ constexpr int VLEN = 8;
+ int j = 0;
+ for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+ __m256i y_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
+ __m256i z_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
+
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
+ y_v = _mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ col_offsets + j + 2 * VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
+ z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ col_offsets + j + 3 * VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
+ w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v);
+
+ if (HAS_BIAS) { // static if
+ x_v = _mm256_add_epi32(
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN);
+ }
+ __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
+ }
+ __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
+ }
+ __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for ( ; j < n / VLEN * VLEN; j += VLEN) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
+
+ if (HAS_BIAS) { // static if
+ x_v = _mm256_add_epi32(
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ x_clamped_v =
+ _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(C_uint8 + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for ( ; j < n; ++j) {
+ int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j];
+ if (HAS_BIAS) { // static if
+ raw += bias[j];
+ }
+
+ float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
+ long rounded = lrintf(ab) + C_zero_point;
+
+ C_uint8[j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+}
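
For reference, here is a scalar version (not part of the commit) of the per-element math that the vectorized requantize_ above implements; it mirrors the scalar tail loop of requantize_ and is only a readability aid for the SIMD paths.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar mirror of requantize_'s tail loop for one output element.
// Pass bias = 0 when the HAS_BIAS path is not used.
inline uint8_t requantize_one(int32_t acc, int32_t A_zero_point,
                              int32_t col_offset, int32_t row_offset,
                              int32_t bias, float C_multiplier,
                              int32_t C_zero_point, bool fuse_relu) {
  // Subtract the zero-point corrections of A and B, then add the bias.
  int32_t raw = acc - A_zero_point * col_offset - row_offset + bias;
  // Scale into the output grid, round, and shift by the output zero point.
  long rounded = std::lrintf(raw * C_multiplier) + C_zero_point;
  long lo = fuse_relu ? static_cast<long>(C_zero_point) : 0l;
  return static_cast<uint8_t>(std::max(lo, std::min(255l, rounded)));
}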
+
+template <bool FUSE_RELU, bool HAS_BIAS>
+static inline __attribute__((always_inline)) void
+requantize_(int32_t A_zero_point, float C_multiplier,
+ int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8,
+ int n, const int32_t *row_offsets, const int32_t *col_offsets,
+ const int32_t *bias) {
+ requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8,
+ n,
+ row_offsets,
+ col_offsets,
+ bias);
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS>
+static inline __attribute__((always_inline)) void
+requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
+ int32_t C_zero_point, const int32_t *C_int32,
+ uint8_t *C_uint8, int n, const int32_t *row_offsets,
+ const int32_t *col_offsets, const int32_t *bias) {
+ requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8,
+ n,
+ row_offsets,
+ col_offsets,
+ bias);
+}
+
+template <bool REMAINDER>
+static inline __attribute__((always_inline)) __m256i
+load_a(const uint8_t* A, __m256i mask_v) {
+ if (REMAINDER) {
+ return _mm256_maskload_epi32(reinterpret_cast<const int *>(A), mask_v);
+ } else {
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(A));
+ }
+}
+
+template <bool SUM_A, bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void
+inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
+ const uint8_t *A, int32_t A_zero_point, const int8_t *Bp,
+ const int32_t *B_zero_point, int32_t *C, int remainder,
+ int32_t *row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll it, so we unroll it manually.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ array<__m256i, 9> a_v = {
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ };
+
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
+ }
+ }
+
+ array<__m256i, 4> a_sum;
+ inner_prod_3x3_packed_<SUM_A, REMAINDER>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp), C, remainder,
+ a_sum.data());
+ if (SUM_A) {
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <bool SUM_A, bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void
+inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
+ int w_in, const uint8_t *A, int32_t A_zero_point,
+ const int8_t *Bp, const int32_t *B_zero_point,
+ int32_t *C, int remainder, int32_t *row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll it, so we unroll it manually.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ array<__m256i, 8> a_v;
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
+ }
+ }
+ }
+
+ array<__m256i, 4> a_sum;
+ inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(),
+ reinterpret_cast<const __m256i *>(Bp),
+ C, remainder, a_sum.data());
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
+ }
+ }
+ }
+
+ array<__m256i, 4> a_sum_temp;
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 8, C, remainder,
+ a_sum_temp.data());
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 16, C, remainder,
+ a_sum_temp.data());
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 24, C, remainder,
+ a_sum_temp.data());
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <bool SUM_A, bool FUSE_RELU>
+static inline __attribute__((always_inline))
+void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t *col_offsets,
+ const int32_t *bias)
+{
+ constexpr int S = 3;
+ constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
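+  // Process the channels 32 at a time; the tail (K % 32) is handled by the
+  // masked REMAINDER path below.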
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3_packed_<SUM_A>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, &B_zero_point,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3_packed_<SUM_A, true>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, &B_zero_point,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_<FUSE_RELU, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
+template <bool SUM_A, bool FUSE_RELU>
+static inline __attribute__((always_inline))
+void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t *col_offsets,
+ const int32_t *bias)
+{
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<SUM_A>(
+ T, H, W, K, t_in, h_in, w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
+ Bp + k * 28, &B_zero_point,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<SUM_A, true>(
+ T, H, W, K, t_in, h_in, w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
+ Bp + k * 28, &B_zero_point,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_<FUSE_RELU, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
+template <bool SUM_A>
+static inline __attribute__((always_inline)) void
+depthwise_3x3_per_channel_quantization_kernel_(
+ int H, int W, int K, int h, int w, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ const int32_t *B_zero_point, const int8_t *Bp,
+ const float *C_multiplier, int32_t C_zero_point,
+ int32_t *C_int32, uint8_t *C_uint8,
+ int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) {
+ constexpr int S = 3;
+ constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3_packed_<SUM_A, false/*remainder*/, true/*per-channel*/>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, B_zero_point + k,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3_packed_<SUM_A, true/*remainder*/, true/*per-channel*/>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, B_zero_point + k,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_per_channel_<false, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
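+// Returns the pair of factors of n that are closest to each other, smaller one
+// first (e.g., closest_factors_(6) == {2, 3}); used below to shape the 2D
+// thread grid.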
+static pair<int, int> closest_factors_(int n) {
+ int a = (int)std::sqrt(n);
+ while (n % a != 0) {
+ a--;
+ }
+ return { a, n / a }; // a <= n / a
+}
+
+// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
+// This implementation should be general enough to handle not just 3x3 but
+// other filter shapes by parameterizing with R and S; it is restricted to 3x3
+// for now.
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+static inline __attribute__((always_inline))
+void depthwise_3x3_pad_1_(int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ int32_t B_zero_point, const Packed3x3ConvMatrix &B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ const int32_t *col_offsets, const int32_t *bias,
+ int thread_id, int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
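+  // Work partitioning: when there are at least as many images as threads,
+  // split the batch dimension across threads; otherwise assign a group of
+  // threads to each image and split its output plane across a 2D thread grid
+  // (see closest_factors_ above).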
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
+ if (h_begin == 0) {
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+    for (h = std::max(1, h_begin); h < std::min(H_OUT - 1, h_end); ++h) {
+ if (w_begin == 0) {
+ w = 0;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+ if (h_end == H_OUT) {
+ h = H_OUT - 1;
+ w = 0;
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+ } // for each n
+}
+
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+static inline __attribute__((always_inline))
+void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix &B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ const int32_t *col_offsets, const int32_t *bias,
+ int thread_id, int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_t * num_threads_h 2D grid of threads
+    int num_threads_t, num_threads_h;
+    // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ C_temp =
+ FUSE_RESCALE
+ ? C_int32
+ : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point,
+ A_base, B_zero_point, Bp, C_multiplier,
+ C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets,
+ bias);
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+template <bool FUSE_RESCALE = true>
+static inline __attribute__((always_inline)) void
+depthwise_3x3_per_channel_quantization_pad_1_(
+ int N, int H, int W, int K, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point,
+ const Packed3x3ConvMatrix &B, const float *C_multiplier,
+ int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8,
+ const int32_t *col_offsets, const int32_t *bias, int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
+ if (h_begin == 0) {
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+    for (h = std::max(1, h_begin); h < std::min(H_OUT - 1, h_end); ++h) {
+ if (w_begin == 0) {
+ w = 0;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+ if (h_end == H_OUT) {
+ h = H_OUT - 1;
+ w = 0;
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+ } // for each n
+}
+
+// assumption: W > 3 and H > 3
+void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const Packed3x3ConvMatrix& B,
+ int32_t* C,
+ int thread_id,
+ int num_threads) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ }
+}
+
+void depthwise_3x3_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads,
+ bool fuse_relu) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
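+  // C_int32_temp (above) is per-call scratch for one output pixel's 32-bit
+  // accumulators (K rounded up to a multiple of 32); it is reused for every
+  // pixel because requantization is fused.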
+ if (fuse_relu) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+ } else {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+ }
+}
+
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const Packed3x3x3ConvMatrix& B,
+ int32_t* C,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+}
+
+static void depthwise_3x3x3_pad_1_(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+}
+
+static void depthwise_3x3x3_pad_1_relu_fused_(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+}
+
+void depthwise_3x3x3_pad_1(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id, int num_threads) {
+  // Note: inlining the following two helper functions has been observed to
+  // cause a stack overflow (each allocates a large temporary accumulator
+  // buffer on the stack), so they are kept as separate functions.
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_relu_fused_(
+ N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A,
+ B_zero_point, B, C_multiplier, C_zero_point, C,
+ col_offsets, bias, thread_id, num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w,
+ A_zero_point, A, B_zero_point, B, C_multiplier,
+ C_zero_point, C, col_offsets, bias,
+ thread_id, num_threads);
+ }
+}
+
+void depthwise_3x3_per_channel_quantization_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
+ const float *C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+}
+
+} // namespace fbgemm2
diff --git a/src/FbgemmI8Depthwise.h b/src/FbgemmI8Depthwise.h
new file mode 100644
index 0000000..bc62c84
--- /dev/null
+++ b/src/FbgemmI8Depthwise.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+
+namespace fbgemm2
+{
+
+// KERNEL_PROD is the product of all kernel dimensions.
+// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
+template <int KERNEL_PROD>
+class PackedDepthWiseConvMatrix
+{
+ public:
+ // smat in RSG layout
+ PackedDepthWiseConvMatrix(int K, const std::int8_t *smat);
+ virtual ~PackedDepthWiseConvMatrix();
+
+ const std::int8_t* PackedMat() const {
+ return pmat_;
+ }
+
+ private:
+ int K_;
+ std::int8_t* pmat_;
+}; // class PackedDepthWiseConvMatrix
+
+using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
+using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8
+ * @param A The input image in NHWK layout
+ * @param Bp The pre-packed filter
+ */
+void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const Packed3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0, int num_threads = 1);
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8
+ * This version is fused with requantization.
+ */
+void depthwise_3x3_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp,
+ float C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ int thread_id = 0, int num_threads = 1, bool fuse_relu = false);
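+// A minimal usage sketch (illustrative only; buffer and parameter names such
+// as B_int8 are placeholders assumed to be set up by the caller):
+//   Packed3x3ConvMatrix Bp(K, B_int8);  // pack the 3x3 filter once
+//   depthwise_3x3_pad_1(N, H, W, K, 1, 1, A_zero_point, A, B_zero_point, Bp,
+//                       C_multiplier, C_zero_point, C, col_offsets, bias);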
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8
+ * This version is fused with requantization and uses per-channel quantization.
+ */
+void depthwise_3x3_per_channel_quantization_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const std::int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
+ const float *C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ int thread_id = 0, int num_threads = 1);
+
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const Packed3x3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0, int num_threads = 1);
+
+void depthwise_3x3x3_pad_1(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp,
+ float C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ bool fuse_relu = false, int thread_id = 0, int num_threads = 1);
+
+} // namespace fbgemm2
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc
new file mode 100644
index 0000000..723a467
--- /dev/null
+++ b/src/FbgemmI8Spmdm.cc
@@ -0,0 +1,508 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8Spmdm.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+
+#include <immintrin.h>
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double spmdm_initial_time = 0.0;
+double spmdm_transpose_uint8_time = 0.0;
+double spmdm_transpose_32xN_time = 0.0;
+double spmdm_compute_time = 0.0;
+double spmdm_transpose_Nx32_time = 0.0;
+double spmdm_run_time = 0.0;
+#endif
+
+using namespace std;
+
+namespace fbgemm2 {
+
+CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols)
+ : num_rows_(num_of_rows),
+ colptr_(num_of_cols + 1),
+ hyper_sparse_(false),
+ old_nnz_(-1) {}
+
+double CompressedSparseColumn::Density() const {
+ return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols());
+}
+
+bool CompressedSparseColumn::IsHyperSparse() const {
+ if (NumOfNonZeros() != old_nnz_) {
+ old_nnz_ = NumOfNonZeros();
+    // Hyper-sparse when the average number of non-zeros per row is below 0.08.
+ hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08;
+ }
+
+ return hyper_sparse_;
+}
+
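+// Transpose an 8 x N uint8_t block: src holds 8 rows of length N (row stride
+// ld_src); dst receives N rows of length 8 (row stride ld_dst).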
+static void transpose_8rows(
+ int N,
+ const uint8_t* src,
+ int ld_src,
+ uint8_t* dst,
+ int ld_dst) {
+ constexpr int M = 8;
+ int j;
+ // vectorized loop
+ for (j = 0; j < N / 32 * 32; j += 32) {
+ // a : a0 a1 ... a31
+ // b : b0 b1 ... b31
+ // c : c0 c1 ... c31
+ // d : d0 d1 ... d31
+ __m256i a = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 0 * ld_src));
+ __m256i b = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 1 * ld_src));
+ __m256i c = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 2 * ld_src));
+ __m256i d = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 3 * ld_src));
+ __m256i e = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 4 * ld_src));
+ __m256i f = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 5 * ld_src));
+ __m256i g = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 6 * ld_src));
+ __m256i h = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 7 * ld_src));
+
+ // even-odd interleaving
+ // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23
+ // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31
+ // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23
+ // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31
+ __m256i ab_lo = _mm256_unpacklo_epi8(a, b);
+ __m256i ab_hi = _mm256_unpackhi_epi8(a, b);
+ __m256i cd_lo = _mm256_unpacklo_epi8(c, d);
+ __m256i cd_hi = _mm256_unpackhi_epi8(c, d);
+ __m256i ef_lo = _mm256_unpacklo_epi8(e, f);
+ __m256i ef_hi = _mm256_unpackhi_epi8(e, f);
+ __m256i gh_lo = _mm256_unpacklo_epi8(g, h);
+ __m256i gh_hi = _mm256_unpackhi_epi8(g, h);
+
+ // 4-row interleaving but permuted at 128-bit granularity
+ // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19
+ // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23
+ // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27
+ // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... a-d31
+ __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo);
+ __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo);
+ __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi);
+ __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi);
+ __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo);
+ __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo);
+ __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi);
+ __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi);
+
+ // 8-row interleaving
+ __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0);
+ __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0);
+ __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1);
+ __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1);
+ __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2);
+ __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2);
+ __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3);
+ __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3);
+
+    // Store with the 128-bit lanes permuted so that everything ends up in order
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst),
+ _mm256_castsi256_si128(y0));
+ *reinterpret_cast<int64_t*>(dst + (j + 1) * ld_dst) =
+ _mm256_extract_epi64(y0, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst),
+ _mm256_castsi256_si128(y1));
+ *reinterpret_cast<int64_t*>(dst + (j + 3) * ld_dst) =
+ _mm256_extract_epi64(y1, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst),
+ _mm256_castsi256_si128(y2));
+ *reinterpret_cast<int64_t*>(dst + (j + 5) * ld_dst) =
+ _mm256_extract_epi64(y2, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst),
+ _mm256_castsi256_si128(y3));
+ *reinterpret_cast<int64_t*>(dst + (j + 7) * ld_dst) =
+ _mm256_extract_epi64(y3, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst),
+ _mm256_castsi256_si128(y4));
+ *reinterpret_cast<int64_t*>(dst + (j + 9) * ld_dst) =
+ _mm256_extract_epi64(y4, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst),
+ _mm256_castsi256_si128(y5));
+ *reinterpret_cast<int64_t*>(dst + (j + 11) * ld_dst) =
+ _mm256_extract_epi64(y5, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst),
+ _mm256_castsi256_si128(y6));
+ *reinterpret_cast<int64_t*>(dst + (j + 13) * ld_dst) =
+ _mm256_extract_epi64(y6, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst),
+ _mm256_castsi256_si128(y7));
+ *reinterpret_cast<int64_t*>(dst + (j + 15) * ld_dst) =
+ _mm256_extract_epi64(y7, 1);
+ *reinterpret_cast<int64_t*>(dst + (j + 16) * ld_dst) =
+ _mm256_extract_epi64(y0, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 17) * ld_dst) =
+ _mm256_extract_epi64(y0, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 18) * ld_dst) =
+ _mm256_extract_epi64(y1, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 19) * ld_dst) =
+ _mm256_extract_epi64(y1, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 20) * ld_dst) =
+ _mm256_extract_epi64(y2, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 21) * ld_dst) =
+ _mm256_extract_epi64(y2, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 22) * ld_dst) =
+ _mm256_extract_epi64(y3, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 23) * ld_dst) =
+ _mm256_extract_epi64(y3, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 24) * ld_dst) =
+ _mm256_extract_epi64(y4, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 25) * ld_dst) =
+ _mm256_extract_epi64(y4, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 26) * ld_dst) =
+ _mm256_extract_epi64(y5, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 27) * ld_dst) =
+ _mm256_extract_epi64(y5, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 28) * ld_dst) =
+ _mm256_extract_epi64(y6, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 29) * ld_dst) =
+ _mm256_extract_epi64(y6, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 30) * ld_dst) =
+ _mm256_extract_epi64(y7, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 31) * ld_dst) =
+ _mm256_extract_epi64(y7, 3);
+ }
+
+ // scalar loop for remainder
+ for (; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ dst[j * ld_dst + i] = src[j + i * ld_src];
+ }
+ }
+}
+
+// TODO: fallback when AVX2 is not available
+void CompressedSparseColumn::SpMDM(
+ const block_type_t& block,
+ const uint8_t* A,
+ int lda,
+ bool accumulation,
+ int32_t* C,
+ int ldc) const {
+ int K = NumOfRows();
+ int N = block.col_size;
+
+ if (K == 0 || N == 0) {
+ return;
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
+ t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+ t_very_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ uint8_t A_buffer[K * 32] __attribute__((aligned(64)));
+ int32_t C_buffer[N * 32] __attribute__((aligned(64)));
+
+ // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
+ // each non-zero in B, we'd need to access the corresponding column in A.
+ // This results in strided access, which we want to avoid.
+ // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T
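+  // A_buffer below holds a 32 x K tile of A stored transposed (K x 32), and
+  // C_buffer holds the matching 32 x N tile of C stored transposed (N x 32).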
+
+ if (IsHyperSparse()) {
+ // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications.
+ // If NNZ/K is small, it's not worth doing transpose so we just use this
+ // scalar loop.
+ if (!accumulation) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size;
+ ++j) {
+ C[(i - block.row_start) * ldc + j - block.col_start] = 0;
+ }
+ }
+ }
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = block.row_start; i < block.row_start + block.row_size;
+ ++i) {
+ C[(i - block.row_start) * ldc + j - block.col_start] +=
+ A[i * lda + row] * w;
+ }
+ }
+ } // for each column of B
+ return;
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_initial_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Take 32 rows at a time
+ int i_end = block.row_start + block.row_size;
+ for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
+ // Transpose 32 x K submatrix of A
+ if (i_end - i1 < 32) {
+ uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64)));
+ for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
+ transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
+ }
+
+ for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) {
+ memcpy(
+ A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t));
+ }
+ memset(
+ A_temp_buffer + (i_end - i1) * K,
+ 0,
+ (32 - (i_end - i1)) * K * sizeof(uint8_t));
+ for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) {
+ transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32);
+ }
+ } else {
+ for (int i2 = 0; i2 < 32; i2 += 8) {
+ transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
+ }
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_uint8_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ if (accumulation) {
+ // Transpose 32 x N submatrix of C to fill N x 32 C_buffer
+ transpose_simd(
+ std::min(32, i_end - i1),
+ N,
+ reinterpret_cast<const float*>(C + (i1 - block.row_start) * ldc),
+ ldc,
+ reinterpret_cast<float*>(C_buffer),
+ 32);
+ } else {
+ memset(C_buffer, 0, N * 32 * sizeof(int32_t));
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_32xN_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ for (int j = 0; j < block.col_size; ++j) {
+ int j_start = j + block.col_start;
+ int k = colptr_[j_start];
+ int k_end_aligned =
+ colptr_[j_start] + (colptr_[j_start + 1] - colptr_[j_start]) / 4 * 4;
+
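+      // Handle four non-zeros of this column at a time: pack their four int8
+      // weights into one 32-bit lane, broadcast it, interleave the four
+      // corresponding 32-byte rows of A_buffer, and accumulate with
+      // vpmaddubsw + vpmaddwd.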
+ for (; k < k_end_aligned; k += 4) {
+ __m256i w =
+ _mm256_set1_epi32(*(reinterpret_cast<const int32_t*>(&values_[k])));
+ array<__m256i, 4> a;
+ a[0] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32]));
+ a[1] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 1] * 32]));
+ a[2] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 2] * 32]));
+ a[3] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 3] * 32]));
+
+ __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
+ __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
+ __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
+ __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
+
+ a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
+ a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
+ a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
+ a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
+
+ array<__m256i, 4> ab;
+ ab[0] = _mm256_maddubs_epi16(a[0], w);
+ ab[1] = _mm256_maddubs_epi16(a[1], w);
+ ab[2] = _mm256_maddubs_epi16(a[2], w);
+ ab[3] = _mm256_maddubs_epi16(a[3], w);
+
+ __m256i one = _mm256_set1_epi16(1);
+ ab[0] = _mm256_madd_epi16(ab[0], one);
+ ab[1] = _mm256_madd_epi16(ab[1], one);
+ ab[2] = _mm256_madd_epi16(ab[2], one);
+ ab[3] = _mm256_madd_epi16(ab[3], one);
+
+ array<__m256i, 4> t;
+ t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
+ t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
+ t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
+ t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
+
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 0 * 8])),
+ t[0]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 1 * 8])),
+ t[1]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 2 * 8])),
+ t[2]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 3 * 8])),
+ t[3]));
+ }
+
+ int remainder = colptr_[j_start + 1] - k;
+ assert(remainder < 4);
+ if (remainder > 0) {
+ int32_t temp_w = 0;
+ for (int r = 0; r < remainder; ++r) {
+ (reinterpret_cast<int8_t*>(&temp_w))[r] = values_[k + r];
+ }
+ __m256i w = _mm256_set1_epi32(temp_w);
+ array<__m256i, 4> a;
+ a[0] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32]));
+ a[1] = remainder > 1
+ ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &A_buffer[rowidx_[k + 1] * 32]))
+ : _mm256_setzero_si256();
+ a[2] = remainder > 2
+ ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &A_buffer[rowidx_[k + 2] * 32]))
+ : _mm256_setzero_si256();
+ a[3] = _mm256_setzero_si256();
+
+ __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
+ __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
+ __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
+ __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
+
+ a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
+ a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
+ a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
+ a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
+
+ array<__m256i, 4> ab;
+ ab[0] = _mm256_maddubs_epi16(a[0], w);
+ ab[1] = _mm256_maddubs_epi16(a[1], w);
+ ab[2] = _mm256_maddubs_epi16(a[2], w);
+ ab[3] = _mm256_maddubs_epi16(a[3], w);
+
+ __m256i one = _mm256_set1_epi16(1);
+ ab[0] = _mm256_madd_epi16(ab[0], one);
+ ab[1] = _mm256_madd_epi16(ab[1], one);
+ ab[2] = _mm256_madd_epi16(ab[2], one);
+ ab[3] = _mm256_madd_epi16(ab[3], one);
+
+ array<__m256i, 4> t;
+ t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
+ t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
+ t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
+ t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
+
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 0 * 8])),
+ t[0]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 1 * 8])),
+ t[1]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 2 * 8])),
+ t[2]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 3 * 8])),
+ t[3]));
+ }
+ } // for each column of B
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_compute_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Transpose N x 32 C_buffer to fill 32 x N submatrix of C
+ transpose_simd(
+ N,
+ std::min(32, i_end - i1),
+ reinterpret_cast<const float*>(C_buffer),
+ 32,
+ reinterpret_cast<float*>(C + (i1 - block.row_start) * ldc),
+ ldc);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_Nx32_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt =
+ std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
+ .count();
+ spmdm_run_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+}
+
+} // namespace fbgemm2
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
new file mode 100644
index 0000000..30160d1
--- /dev/null
+++ b/src/GenerateKernel.h
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <map>
+#include <tuple>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * @brief AVX2/AVX512 JIT assembly code generator.
+ * @tparam TA Type of matrix A.
+ * @tparam TB Type of matrix B.
+ * @tparam TC Type of matrix C.
+ * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or
+ * 32-bit (std::int32_t) accumulation.
+ */
+template <typename TA, typename TB, typename TC, typename accT>
+class CodeGenBase {
+ public:
+ using jit_micro_kernel_fp = void (*)(
+ TA* bufferA,
+ TB* bufferB,
+ TB* b_pf,
+ TC* bufferC,
+ int kc,
+ int ldc);
+
+ /**
+ * @brief Constructor for initializing AVX2/AVX512 registers.
+ */
+ CodeGenBase()
+ : CRegs_avx2_{x86::ymm0,
+ x86::ymm1,
+ x86::ymm2,
+ x86::ymm3,
+ x86::ymm4,
+ x86::ymm5,
+ x86::ymm6,
+ x86::ymm7,
+ x86::ymm8,
+ x86::ymm9,
+ x86::ymm10,
+ x86::ymm11},
+ CRegs_avx512_{
+ x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
+ x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
+ x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
+ x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
+ x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
+ x86::zmm25, x86::zmm26, x86::zmm27,
+ } {
+ // vector width in bits
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ vectorWidth_ = 512;
+ } else if (cpuinfo_has_x86_avx2()) {
+ vectorWidth_ = 256;
+ } else {
+ // TODO: Have default path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+ // vector width in elements
+ VLEN_ = vectorWidth_ / 8 * sizeof(TA);
+ }
+
+ /**
+ * @brief Get or Create the instructions for macro-kernel.
+ *
+ * If the problem size (mc, nc) and accumulation flag (accum) can be found in
+ * the code cache (a hash map), then get the macro-kernel instructions
+ * directly from it. Otherwise, create the instructions for macro-kernel, and
+ * store that into the code cache.
+ */
+ template <inst_set_t instSet>
+ jit_micro_kernel_fp
+ getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc, int32_t ldc);
+
+ /**
+ * @brief Generate instructions for initializing the C registers to 0.
+ */
+ template <inst_set_t instSet>
+ void initCRegs(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign = 4);
+
+ /**
+ * @brief Generate instructions for computing block in the rank-k update.
+ */
+ template <inst_set_t instSet>
+ void genComputeBlock(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign = 4);
+
+ /**
+ * @brief Generate instructions for storing the C registers back to the
+ * memory.
+ */
+ template <inst_set_t instSet>
+ void storeCRegs(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign = 4);
+
+ private:
+ asmjit::X86Ymm
+ CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ asmjit::X86Zmm
+ CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
+ int vectorWidth_; ///< Vector width in bits.
+ int VLEN_; ///< Vector width in elements.
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static thread_local std::map<std::tuple<bool, int, int>, jit_micro_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+};
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local std::map<
+ std::tuple<bool, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
+} // namespace fbgemm2
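
The getOrCreate() declaration above describes a thread-local code cache keyed on (accum, mc, nc). A minimal sketch of that memoization pattern follows; the names are illustrative, not the actual implementation.

#include <map>
#include <tuple>

using kernel_fp = void (*)();  // stand-in for jit_micro_kernel_fp

kernel_fp getOrCreateSketch(bool accum, int mc, int nc) {
  // one cache per thread, mirroring the thread_local codeCache_ member
  static thread_local std::map<std::tuple<bool, int, int>, kernel_fp> cache;
  auto key = std::make_tuple(accum, mc, nc);
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second;     // reuse previously generated code
  }
  kernel_fp fn = nullptr;  // JIT-generate and register the kernel here
  cache[key] = fn;
  return fn;
}
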
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
new file mode 100644
index 0000000..2ffe3ab
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX2 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for computing a block in the rank-k update of the
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Ymm AReg = x86::ymm12;
+
+ asmjit::X86Ymm tmpReg = x86::ymm14;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ // broadcast A
+ a->vpbroadcastw(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ for (int j = 0; j < colRegs; ++j) {
+ a->vpmaddubsw(
+ tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vpaddsw(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ tmpReg,
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+      // Prefetching hurts performance in some cases because the prefetch
+      // instruction itself consumes an issue slot in the pipeline, thus
+      // slowing down the kernel.
+ // if((i == rowRegs - 1) && j % 2 == 0){
+ // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t)));
+ //}
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for storing the C registers back to the memory in
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ asmjit::X86Xmm extractDest128 = x86::xmm15;
+ asmjit::X86Ymm extractDest256 = x86::ymm15;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ for (int j = 0; j < colRegs; ++j) {
+ for (int idx = 0; idx < 2; ++idx) {
+ a->vextracti128(
+ extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx);
+ a->vpmovsxwd(extractDest256, extractDest128);
+ asmjit::X86Mem destAddr = x86::dword_ptr(
+ a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
+ if (accum) {
+ a->vpaddd(extractDest256, extractDest256, destAddr);
+ }
+ a->vmovups(destAddr, extractDest256);
+ }
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX2 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR;
+ // constexpr int nRegBlockSize =
+ // PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
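
As a reading aid for the kernel above, here is a scalar sketch of the vpmaddubsw/vpaddsw step, assuming two interleaved k values per packed column (ROW_INTERLEAVE == 2, as the vpbroadcastw of two A bytes suggests). It is a reference for the arithmetic, not the JIT kernel itself.

#include <algorithm>
#include <cstdint>

static int16_t sat16(int32_t v) {
  return static_cast<int16_t>(std::min(32767, std::max(-32768, v)));
}

// One (A row) x (B column block) step: a points at two interleaved A bytes,
// b at the interleaved B pair for each of the n output columns.
void acc16_step_ref(const uint8_t* a, const int8_t* b, int16_t* c, int n) {
  for (int j = 0; j < n; ++j) {
    // vpmaddubsw: u8*s8 products of adjacent pairs, summed with s16 saturation
    int16_t prod =
        sat16(int32_t(a[0]) * b[2 * j] + int32_t(a[1]) * b[2 * j + 1]);
    // vpaddsw: saturating add into the C accumulator register
    c[j] = sat16(int32_t(c[j]) + prod);
  }
}
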
diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc
new file mode 100644
index 0000000..e613cf1
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16_avx512.cc
@@ -0,0 +1,295 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing a block in the rank-k update of
+ * the 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Zmm AReg = x86::zmm29;
+
+ asmjit::X86Zmm tmpReg = x86::zmm30;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ // broadcast A
+ a->vpbroadcastw(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ for (int j = 0; j < colRegs; ++j) {
+ a->vpmaddubsw(
+ tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vpaddsw(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ tmpReg,
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+      // Prefetching hurts performance in some cases because the prefetch
+      // instruction itself consumes an issue slot in the pipeline, thus
+      // slowing down the kernel.
+ // if((i == rowRegs - 1) && j % 2 == 0){
+ // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t)));
+ //}
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ asmjit::X86Ymm extractDest256 = x86::ymm31;
+ asmjit::X86Zmm extractDest512 = x86::zmm31;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ for (int j = 0; j < colRegs; ++j) {
+ for (int idx = 0; idx < 2; ++idx) {
+ a->vextracti32x8(
+ extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx);
+ a->vpmovsxwd(extractDest512, extractDest256);
+ asmjit::X86Mem destAddr = x86::dword_ptr(
+ a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
+ if (accum) {
+ a->vpaddd(extractDest512, extractDest512, destAddr);
+ }
+ a->vmovups(destAddr, extractDest512);
+ }
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
+ // constexpr int nRegBlockSize =
+ // PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
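
The storeCRegs specializations above widen each saturated int16 partial sum to int32 before writing C. A scalar restatement of that store, for reference only:

#include <cstdint>

void store_acc16_ref(const int16_t* c16, int32_t* C, int n, bool accum) {
  for (int j = 0; j < n; ++j) {
    int32_t widened = static_cast<int32_t>(c16[j]);  // vpmovsxwd sign-extension
    C[j] = accum ? C[j] + widened : widened;         // vpaddd, then vmovups
  }
}
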
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
new file mode 100644
index 0000000..dc8c6d3
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -0,0 +1,310 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX2 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs_avx2_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for computing a block in the rank-k update of the
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Ymm AReg = x86::ymm12;
+
+ // used for matrix B
+ asmjit::X86Ymm BReg = x86::ymm13;
+
+ // Contains 16-bit 1s
+ asmjit::X86Ymm oneReg = x86::ymm15;
+
+ // temporary register
+ asmjit::X86Ymm res1 = x86::ymm14;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpmaddubsw(res1, AReg, BReg);
+ a->vpmaddwd(res1, oneReg, res1);
+ a->vpaddd(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ res1,
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX2 instructions for storing the C registers back to the memory in
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ // temp register
+ asmjit::X86Ymm tmpReg = x86::ymm14;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)),
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX2 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
+
+ asmjit::X86Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->mov(C_Offset, 0);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs);
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
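
A scalar sketch of the 32-bit accumulation sequence above, assuming four interleaved k values per packed column (ROW_INTERLEAVE == 4, as the vpbroadcastd of four A bytes suggests): vpmaddubsw forms two saturated 16-bit pair sums, vpmaddwd against the register of 16-bit ones widens and adds them to 32 bits, and vpaddd accumulates into C. Illustrative only.

#include <algorithm>
#include <cstdint>

static int16_t sat16(int32_t v) {
  return static_cast<int16_t>(std::min(32767, std::max(-32768, v)));
}

void acc32_step_ref(const uint8_t* a, const int8_t* b, int32_t* c, int n) {
  for (int j = 0; j < n; ++j) {
    // vpmaddubsw on the four interleaved bytes yields two saturated s16 sums
    int16_t lo =
        sat16(int32_t(a[0]) * b[4 * j + 0] + int32_t(a[1]) * b[4 * j + 1]);
    int16_t hi =
        sat16(int32_t(a[2]) * b[4 * j + 2] + int32_t(a[3]) * b[4 * j + 3]);
    // vpmaddwd with 16-bit ones: widen both sums to s32 and add them
    c[j] += int32_t(lo) + int32_t(hi);  // vpaddd into the C accumulator
  }
}
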
diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc
new file mode 100644
index 0000000..5cd5684
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32_avx512.cc
@@ -0,0 +1,312 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing a block in the rank-k update of
+ * the 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ asmjit::X86Zmm BReg = x86::zmm30;
+
+ // Contains 16-bit 1s
+ asmjit::X86Zmm oneReg = x86::zmm29;
+
+ // temporary register
+ asmjit::X86Zmm res1 = x86::zmm28;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpmaddubsw(res1, AReg, BReg);
+ a->vpmaddwd(res1, oneReg, res1);
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ res1,
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ // temp register
+ asmjit::X86Zmm tmpReg = x86::zmm28;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
+
+ asmjit::X86Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->mov(C_Offset, 0);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs);
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
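
One AVX512-specific detail above: the 16-bit ones constant is built with vpternlogd(oneReg, oneReg, oneReg, 0xff) followed by vpsrlw, presumably because the EVEX-encoded vpcmpeqw writes a mask register rather than a vector (the vpcmpeqw call is left commented out). The lane-level arithmetic, for reference:

#include <cstdint>

uint16_t ones_lane_ref() {
  uint16_t lane = 0xFFFF;  // vpternlogd with immediate 0xff sets every bit
  lane >>= 15;             // vpsrlw by 15 leaves 0x0001 in each 16-bit lane
  return lane;             // == 1
}
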
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
new file mode 100644
index 0000000..543d99b
--- /dev/null
+++ b/src/PackAMatrix.cc
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAMatrix<T, accT>::PackAMatrix(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const T* smat,
+ int32_t ld,
+ inpType* pmat,
+ int32_t groups,
+ accT zero_pt)
+ : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ if (!pmat) {
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+}
+
+template <typename T, typename accT>
+void PackAMatrix<T, accT>::pack(const block_type_t& block) {
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+
+ BaseType::packedBlock(block_p);
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ T* out = BaseType::getBuf();
+ if (tr) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = smat_[i + ld_ * j];
+ out[addr(i, j) - addr(block.row_start, block.col_start)] = val;
+ }
+ // zero fill
+ // Please note that we zero fill, not zero_pt fill, because for
+ // requantization original, i.e., not padded, dimensions are used. If we
+ // were to use padded dimensions for requantization, we would zero_pt
+ // fill.
+ // For example, consider the following dot product:
+ // A = .3(5-15), .3(20-15) //.3 is scale and 15 is zero_pt
+ // B = .4(1+10), .4(4+10) // .4 is scale and -10 is zero_pt
+ //
+ // numElements(A) = 2 and numElements(B) = 2
+ //
+ // Dot product is (real): -3*4.4+1.5*5.6 = -4.8
+ // Dot product is (quantized): 5*1+20*4 = 85
+ //
+ // requantization: .3*.4(85 - (5+20)*(-10) - (1+4)*(15) +
+ // numElements(A)*(15)(-10)) = -4.8
+ //
+ // In the above adding one more element zero in the quantized domain,
+ // i.e., the quantized vectors become:
+ // A_q = 5, 20, 0
+ // B_q = 1, 4, 0
+ //
+ // and requantization with numElements(A) = 2 will produce the same
+ // answer (-4.8).
+ //
+ // Also in the above adding one more element zero_pt in the quantized
+ // domain, i.e., the quantized vectors become:
+ // A_q = 5, 20, 15
+ // B_q = 1, 4, -10
+ //
+ // and requantization with numElements(A) = 3 will produce the same
+ // answer (-4.8).
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
+ out[addr(i, j) - addr(block.row_start, block.col_start)] = 0;
+ }
+ }
+ } else {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int buf_idx = i - block.row_start;
+ memcpy(
+ out + buf_idx * BaseType::blockColSize(),
+ smat_ + i * ld_ + block.col_start,
+ block.col_size * sizeof(T));
+ // zero fill
+ for (int j = block.col_size; j < block_p.col_size; ++j) {
+ out[buf_idx * BaseType::blockColSize() + j] = 0;
+ }
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAMatrix<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAMatrix<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template class PackAMatrix<uint8_t, int32_t>;
+template class PackAMatrix<uint8_t, int16_t>;
+} // namespace fbgemm2
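
The zero-fill comment in pack() above can be checked numerically. A small, self-contained sketch of that requantization arithmetic, using the same values as the comment (illustrative only, not a library API):

#include <cassert>
#include <cmath>

void requant_zero_fill_check() {
  const float a_scale = 0.3f, b_scale = 0.4f;
  const int a_zp = 15, b_zp = -10, n = 2;
  const int A_q[] = {5, 20}, B_q[] = {1, 4};  // quantized A and B
  int dot_q = 0, a_sum = 0, b_sum = 0;
  for (int i = 0; i < n; ++i) {
    dot_q += A_q[i] * B_q[i];
    a_sum += A_q[i];
    b_sum += B_q[i];
  }
  // requantization with the original element count n
  float real = a_scale * b_scale *
      (dot_q - a_sum * b_zp - b_sum * a_zp + n * a_zp * b_zp);
  // matches the real dot product -3 * 4.4 + 1.5 * 5.6 = -4.8; appending a
  // quantized 0 to both vectors leaves every term above unchanged
  assert(std::fabs(real - (-4.8f)) < 1e-4f);
}
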
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
new file mode 100644
index 0000000..7012289
--- /dev/null
+++ b/src/PackAWithIm2Col.cc
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithIm2Col<T, accT>::PackAWithIm2Col(
+ const conv_param_t& conv_p,
+ const T* sdata,
+ inpType* pmat,
+ int32_t zero_pt,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>(
+ conv_p.MB * conv_p.OH * conv_p.OW,
+ conv_p.KH * conv_p.KW * conv_p.IC,
+ pmat,
+ zero_pt),
+ conv_p_(conv_p),
+ sdata_(sdata) {
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = static_cast<T*>(
+ aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
+ }
+ if (row_offset) {
+ row_offset_ = row_offset;
+ } else {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = static_cast<int32_t*>(
+ aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+
+ BaseType::packedBlock(block_p);
+ T* out = BaseType::getBuf();
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int n = i / (conv_p_.OH * conv_p_.OW);
+ int hw = i % (conv_p_.OH * conv_p_.OW);
+ int w = hw % conv_p_.OW;
+ int h = hw / conv_p_.OW;
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ int c = j % conv_p_.IC;
+ int rs = j / conv_p_.IC;
+ int s = rs % conv_p_.KW;
+ int r = rs / conv_p_.KW;
+
+ int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
+ int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
+ // Please note that padding for convolution should be filled with zero_pt
+ if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = BaseType::zeroPoint();
+ } else {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = sdata_
+ [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c];
+ }
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = 0;
+ }
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[ r * BaseType::blockColSize() + c ];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithIm2Col<uint8_t, int32_t>;
+template class PackAWithIm2Col<uint8_t, int16_t>;
+} // namespace fbgemm2
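
A scalar restatement of the index mapping used in PackAWithIm2Col::pack() above: packed row i selects an output pixel (n, h, w), packed column j selects a filter tap (r, s) and input channel c, and out-of-bounds taps read the zero point. Field names mirror conv_param_t as it is used in this file; the helper itself is illustrative.

#include <cstdint>

uint8_t im2col_at(const uint8_t* src, int IC, int IH, int IW, int KW,
                  int stride_h, int stride_w, int pad_h, int pad_w,
                  int OH, int OW, int i, int j, uint8_t zero_pt) {
  int n = i / (OH * OW), hw = i % (OH * OW);
  int h = hw / OW, w = hw % OW;   // output pixel coordinates
  int c = j % IC, rs = j / IC;
  int s = rs % KW, r = rs / KW;   // filter tap coordinates
  int h_in = -pad_h + h * stride_h + r;
  int w_in = -pad_w + w * stride_w + s;
  if (h_in < 0 || h_in >= IH || w_in < 0 || w_in >= IW) {
    return zero_pt;               // padding uses zero_pt, not 0
  }
  return src[((n * IH + h_in) * IW + w_in) * IC + c];  // NHWC input layout
}
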
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
new file mode 100644
index 0000000..30d94f8
--- /dev/null
+++ b/src/PackBMatrix.cc
@@ -0,0 +1,144 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackBMatrix<T, accT>::PackBMatrix(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const T* smat,
+ int32_t ld,
+ inpType* pmat,
+ int32_t groups,
+ accT zero_pt)
+ : PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB;
+ row_interleave_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // Error
+    assert(0 && "unknown architecture");
+ }
+ block_type_t block{0, BaseType::numRows(), 0, BaseType::numCols()};
+ BaseType::packedBlock(block);
+ if (!pmat) {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = (T*)aligned_alloc(
+ 64,
+ BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() *
+ BaseType::bcol_ * sizeof(T));
+ }
+ pack(block);
+}
+
+template <typename T, typename accT>
+void PackBMatrix<T, accT>::pack(const block_type_t& block) {
+ assert((BaseType::blockRowSize() % row_interleave_) == 0);
+
+ BaseType::packedBlock(block);
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = tr ? smat_[i + ld_ * j] : smat_[i * ld_ + j];
+ out[addr(i, j) - addr(block.row_start, block.col_start)] =
+ tconv(val, out[addr(i, j)]);
+ }
+ }
+ // fill the remaining with zero.
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (int i = block.row_start + block.row_size;
+ i < (block.row_start + block.row_size + row_interleave_ - 1) /
+ row_interleave_ * row_interleave_;
+ ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; j++) {
+ out[addr(i, j) - addr(block.row_start, block.col_start)] =
+ tconv(0, out[addr(i, j)]);
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset = (r % BaseType::blockRowSize() / row_interleave_) *
+ BaseType::blockColSize() * row_interleave_ +
+ (c % BaseType::blockColSize()) * row_interleave_ + r % row_interleave_;
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+ std::cout << "block size:"
+ << "[" << BaseType::blockRowSize() << ", "
+ << BaseType::blockColSize() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
+ auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
+ : BaseType::blockRowSize();
+ for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
+ std::cout << "block:" << nr << ", " << nc << std::endl;
+ auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol()
+ : BaseType::blockColSize();
+ for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
+ ++r) {
+ for (auto c = 0; c < cols * row_interleave_; ++c) {
+ T val =
+ out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
+ BaseType::blockColSize() +
+ nc * BaseType::blockRowSize() * BaseType::blockColSize() +
+ r * BaseType::blockColSize() * row_interleave_ + c];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+ }
+ }
+}
+
+template class PackBMatrix<int8_t, int32_t>;
+template class PackBMatrix<int8_t, int16_t>;
+} // namespace fbgemm2
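
Within one KCB x NCB block, PackBMatrix::addr() above stores rows in groups of row_interleave, so element (r, c) of the block lands at (r / RI) * NCB * RI + c * RI + (r % RI). A tiny helper making that explicit (illustrative only):

#include <cstdint>

int32_t inblock_addr_ref(int32_t r, int32_t c, int32_t NCB, int32_t RI) {
  // r and c are already reduced modulo the block dimensions
  return (r / RI) * NCB * RI + c * RI + (r % RI);
}
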
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
new file mode 100644
index 0000000..85000ac
--- /dev/null
+++ b/src/PackMatrix.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <iomanip>
+#include <stdexcept>
+#include <type_traits>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename PT, typename inpType, typename accType>
+PackMatrix<PT, inpType, accType>::PackMatrix(
+ int32_t rows,
+ int32_t cols,
+ inpType* buf,
+ int32_t zero_pt)
+ : buf_(buf), nrows_(rows), ncols_(cols), zero_pt_(zero_pt) {
+ bufAllocatedHere_ = false;
+ if (!cpuinfo_initialize()) {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template <typename PT, typename inpType, typename accType>
+int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) {
+ if (cpuinfo_has_x86_avx512f()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else if (cpuinfo_has_x86_avx2()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ return -1;
+}
+
+// int32 accumulation
+template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
+
+// int16 accumulation
+template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ uint8_t,
+ int16_t>;
+
+template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>;
+} // namespace fbgemm2
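
packedBufferSize() above rounds the B dimensions up to whole KCB x NCB blocks before multiplying. A minimal sketch of that rounding, with the constant names taken from PackingTraits (the helper itself is illustrative):

int round_up(int x, int block) {
  return (x + block - 1) / block * block;  // smallest multiple of block >= x
}
// packed B element count ~= round_up(rows, KCB) * round_up(cols, NCB)
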
diff --git a/src/PackWithQuantRowOffset.cc b/src/PackWithQuantRowOffset.cc
new file mode 100644
index 0000000..74eaade
--- /dev/null
+++ b/src/PackWithQuantRowOffset.cc
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <stdexcept>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const float* smat,
+ int32_t ld,
+ inpType* pmat,
+ float scale,
+ int32_t zero_pt,
+ int32_t groups,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ scale_(scale),
+ G_(groups),
+ row_offset_(row_offset) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ rowOffsetAllocatedHere = false;
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unknown architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+ if (!row_offset_) {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = reinterpret_cast<int32_t*>(
+        aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) {
+ assert(block.row_start % BaseType::blockRowSize() == 0);
+ assert(block.col_start % BaseType::blockColSize() == 0);
+ assert(block.row_size <= BaseType::blockRowSize());
+ assert(block.col_size <= BaseType::blockColSize());
+
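+  // Round the number of packed columns up to a multiple of row_interleave_B_
+  // so the packed A block lines up with B's row-interleaved layout; the extra
+  // columns are zero-filled at the end of each row further below.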
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+ assert(block_p.col_size <= BaseType::blockColSize());
+ BaseType::packedBlock(block_p);
+
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ // accumulate into row offset?
+ bool row_offset_acc = (block.col_start != 0);
+ int32_t* row_offset_buf = getRowOffsetBuffer();
+
+ float smat_transposed[block.row_size * block.col_size];
+ if (tr) {
+ transpose_simd(
+ block.col_size,
+ block.row_size,
+ smat_ + block.col_start * ld_ + block.row_start,
+ ld_,
+ smat_transposed,
+ block.col_size);
+ }
+ const float* smat_temp =
+ tr ? smat_transposed : smat_ + block.row_start * ld_ + block.col_start;
+ int32_t ld_temp = tr ? block.col_size : ld_;
+
+#if defined(__AVX2__) && defined(__FMA__)
+ constexpr int VLEN = 8;
+ __m256 inverse_scale_v = _mm256_set1_ps(1.0f / scale_);
+ __m256i shuffle_mask_v = _mm256_set_epi8(
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
+ __m256i permute_mask_v = _mm256_set_epi32(
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+#endif
+
+ for (int i = 0; i < block.row_size; ++i) {
+ int32_t row_sum = row_offset_acc ? row_offset_buf[i] : 0;
+ int j = 0;
+#if defined(__AVX2__) && defined(__FMA__)
+ static_assert(
+ std::is_same<T, uint8_t>::value,
+ "PackAWithQuantRowOffset<T, accT>::pack only works for T == uint8_t");
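+    // Vectorized quantization: q = clamp(x * (1/scale) + zero_point, 0, 255),
+    // computed on 8 floats at a time and then narrowed from 8 int32 lanes to
+    // 8 uint8 bytes by the shuffle/permute sequence below.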
+ for (; j < block.col_size / VLEN * VLEN; j += VLEN) {
+ __m256 val_v = _mm256_loadu_ps(smat_temp + i * ld_temp + j);
+ __m256 transformed_v = _mm256_fmadd_ps(
+ val_v, inverse_scale_v, _mm256_set1_ps(BaseType::zeroPoint()));
+ __m256 clipped_v = _mm256_max_ps(
+ _mm256_set1_ps(std::numeric_limits<uint8_t>::min()),
+ _mm256_min_ps(
+ transformed_v,
+ _mm256_set1_ps(std::numeric_limits<uint8_t>::max())));
+ __m256i res_v = _mm256_cvtps_epi32(clipped_v);
+
+ // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
+ res_v = _mm256_shuffle_epi8(res_v, shuffle_mask_v);
+ res_v = _mm256_permutevar8x32_epi32(res_v, permute_mask_v);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(out + i * BaseType::blockColSize() + j),
+ _mm256_castsi256_si128(res_v));
+
+ for (int j2 = j; j2 < j + VLEN; ++j2) {
+ row_sum += out[i * BaseType::blockColSize() + j2];
+ }
+ }
+#endif
+ for (; j < block.col_size; ++j) {
+ float val = smat_temp[i * ld_temp + j];
+ float transformed = val / scale_ + BaseType::zeroPoint();
+ float clipped = std::min<float>(
+ std::max<float>(transformed, std::numeric_limits<uint8_t>::min()),
+ std::numeric_limits<uint8_t>::max());
+ T res = round(clipped);
+ row_sum += res;
+ out[i * BaseType::blockColSize() + j] = res;
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (; j < block_p.col_size; ++j) {
+ out[i * BaseType::blockColSize() + j] = 0;
+ }
+ row_offset_buf[i] = row_sum;
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAWithQuantRowOffset<T, accT>::addr(int32_t r, int32_t c) const {
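+  // Packed layout: blockRowSize() x blockColSize() tiles stored contiguously,
+  // ordered by block row and then block column; compute the tile's offset
+  // first, then the element's offset inside the tile.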
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithQuantRowOffset<uint8_t, int32_t>;
+
+} // namespace fbgemm2
diff --git a/src/PackWithRowOffset.cc b/src/PackWithRowOffset.cc
new file mode 100644
index 0000000..8722723
--- /dev/null
+++ b/src/PackWithRowOffset.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cassert>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <stdexcept>
+#include <cpuinfo.h>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithRowOffset<T, accT>::PackAWithRowOffset(
+ matrix_op_t trans,
+ uint32_t nRow,
+ uint32_t nCol,
+ const T* smat,
+ uint32_t ld,
+ inpType* pmat,
+ uint32_t groups,
+ int32_t zero_pt,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups),
+ row_offset_(row_offset) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ rowOffsetAllocatedHere = false;
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+    // TODO: Have default slower path
+    assert(0 && "unknown architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+ if (!row_offset_) {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = static_cast<int32_t*>(aligned_alloc(64,
+ BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
+ assert(block.row_start % BaseType::blockRowSize() == 0);
+ assert(block.col_start % BaseType::blockColSize() == 0);
+ assert(block.row_size <= BaseType::blockRowSize());
+ assert(block.col_size <= BaseType::blockColSize());
+
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+ assert(block_p.col_size <= BaseType::blockColSize());
+ BaseType::packedBlock(block_p);
+
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ // accumulate into row offset?
+ bool row_offset_acc = (block.col_start != 0);
+ int32_t* row_offset_buf = getRowOffsetBuffer();
+ if (tr) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int32_t row_sum = row_offset_acc ?
+ row_offset_buf[i - block.row_start] : 0;
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = smat_[i + ld_ * j];
+ row_sum += val;
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = val;
+ }
+ row_offset_buf[i - block.row_start] = row_sum;
+ // zero fill
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size; ++j) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = 0;
+ }
+ }
+ } else {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int buf_idx = i - block.row_start;
+ memcpy(
+ out + buf_idx * BaseType::blockColSize(),
+ smat_ + i * ld_ + block.col_start,
+ block.col_size * sizeof(T));
+ // zero fill
+ for (int j = block.col_size; j < block_p.col_size; ++j) {
+ out[buf_idx * BaseType::blockColSize() + j] = 0;
+ }
+ int32_t row_sum = row_offset_acc ?
+ row_offset_buf[i - block.row_start] : 0;
+ __m256i sum_v = _mm256_setzero_si256();
+ __m256i one_epi16_v = _mm256_set1_epi16(1);
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
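+      // Sum 32 uint8 values per iteration: maddubs(src, 1) yields 16 adjacent
+      // pairwise sums as int16, and madd(.., 1) widens and accumulates them
+      // into 8 int32 lanes that are reduced into row_sum after the loop.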
+ for (int j = block.col_start;
+ j < block.col_start + block.col_size / 32 * 32;
+ j += 32) {
+ __m256i src_v = _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(smat_ + i * ld_ + j));
+ sum_v = _mm256_add_epi32(
+ sum_v,
+ _mm256_madd_epi16(
+ _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
+ }
+ for (int j = block.col_start + block.col_size / 32 * 32;
+ j < block.col_start + block.col_size;
+ ++j) {
+ row_sum += smat_[i * ld_ + j];
+ }
+ alignas(64) std::array<int32_t, 8> temp;
+ _mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v);
+ for (int k = 0; k < 8; ++k) {
+ row_sum += temp[k];
+ }
+ row_offset_buf[i - block.row_start] = row_sum;
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAWithRowOffset<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() {
+  if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+      // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithRowOffset<uint8_t, int32_t>;
+template class PackAWithRowOffset<uint8_t, int16_t>;
+
+} // namespace fbgemm2
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
new file mode 100644
index 0000000..9aedc88
--- /dev/null
+++ b/src/RefImplementations.cc
@@ -0,0 +1,608 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "RefImplementations.h"
+
+#include <cassert>
+#include <cmath>
+
+using namespace std;
+
+namespace fbgemm2 {
+
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const int32_t* inp,
+ uint8_t* out,
+ int32_t C_multiplier,
+ int32_t C_right_shift,
+ int32_t C_zero_point,
+ int32_t A_zero_point,
+ int32_t B_zero_point,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu) {
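+  // Fixed-point requantization:
+  //   out = clamp(((raw * C_multiplier + nudge) >> C_right_shift) + C_zero_point, lo, 255)
+  // where raw = inp - A_zero_point * col_offset - B_zero_point * row_offset (+ bias),
+  // nudge implements round-half-up, and lo is C_zero_point when ReLU is fused.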
+ int64_t nudge = 1ll << std::max(0, C_right_shift - 1);
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ int32_t raw = inp[i * ld + j];
+ raw -= A_zero_point * col_offsets[j];
+ raw -= B_zero_point * row_offsets[i];
+ if (bias) {
+ raw += bias[j];
+ }
+
+ int64_t ab_64 =
+ static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier);
+ int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point;
+
+ out[i * ld + j] = std::max(
+ fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+ }
+}
+
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const int32_t* inp,
+ uint8_t* out,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t A_zero_point,
+ int32_t B_zero_point,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu) {
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ int32_t raw = inp[i * ld + j];
+ raw -= A_zero_point * col_offsets[j];
+ raw -= B_zero_point * row_offsets[i];
+ if (bias) {
+ raw += bias[j];
+ }
+
+ float result = raw * C_multiplier;
+ long rounded = lrintf(result) + C_zero_point;
+ out[i * ld + j] = std::max(
+ fuse_relu ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+ }
+}
+
+void matmul_u8i8acc32_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const uint8_t* Aint8,
+ const int8_t* Bint8,
+ int32_t* Cint32) {
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += static_cast<int32_t>(Aint8[i * lda + k]) *
+ static_cast<int32_t>(Bint8[k * ldb + j]);
+ }
+ Cint32[i * ldc + j] = sum;
+ }
+ }
+}
+
+void matmul_u8i8acc16_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ int brow,
+ const uint8_t* Aint8,
+ const int8_t* Bint8,
+ int32_t* Cint32) {
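+  // Emulates 16-bit accumulation: the product of each adjacent pair of k
+  // elements is added into a saturating int16 accumulator, which is spilled
+  // into a 32-bit accumulator every brow elements.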
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0, sum_32bit = 0;
+ for (int k = 0; k < K; k += 2) {
+ int a0 = Aint8[i * lda + k];
+ int b0 = Bint8[k * ldb + j];
+ int a1 = 0, b1 = 0;
+ if (k + 1 < K) {
+ a1 = Aint8[i * lda + k + 1];
+ b1 = Bint8[(k + 1) * ldb + j];
+ }
+ sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1));
+ if ((k % brow) == (brow - 2)) {
+ sum_32bit += sum;
+ sum = 0;
+ }
+ }
+ Cint32[i * ldc + j] = sum_32bit + sum;
+ }
+ }
+}
+
+void matmul_fp_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const float* Afp32,
+ const float* Bfp32,
+ float* Cfp32) {
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ float sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += Afp32[i * lda + k] * Bfp32[k * ldb + j];
+ }
+ Cfp32[i * ldc + j] = sum;
+ }
+ }
+}
+
+void row_offsets_u8acc32_ref(
+ int M,
+ int K,
+ int ld,
+ const uint8_t* Aint8,
+ int32_t* row_offsets) {
+ // row offset
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += static_cast<int32_t>(Aint8[i * ld + k]);
+ }
+ row_offsets[i] = sum;
+ }
+}
+
+void col_offsets_with_zero_pt_s8acc32_ref(
+ int K,
+ int N,
+ int ld,
+ const int8_t* Bint8,
+ int32_t B_zero_point,
+ int32_t* col_offsets) {
+ for (int j = 0; j < N; ++j) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += Bint8[k * ld + j];
+ }
+ col_offsets[j] = sum - B_zero_point * K;
+ }
+}
+
+void spmdm_ref(
+ int M,
+ const uint8_t* A,
+ int lda,
+ fbgemm2::CompressedSparseColumn& B,
+ bool accumulation,
+ int32_t* C,
+ int ldc) {
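+  // B is in compressed sparse column (CSC) form: for each column j, walk its
+  // nonzeros through ColPtr()/RowIdx()/Values() and accumulate into column j
+  // of C.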
+ int N = B.NumOfCols();
+ if (!accumulation) {
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ C[i * ldc + j] = 0;
+ }
+ }
+ }
+ for (int j = 0; j < N; ++j) {
+ for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) {
+ int row = B.RowIdx()[k];
+ int w = B.Values()[k];
+ for (int i = 0; i < M; ++i) {
+ C[i * ldc + j] += A[i * lda + row] * w;
+ }
+ }
+ } // for each column of B
+}
+
+int32_t clip_16bit(int32_t x) {
+ if (x > std::numeric_limits<int16_t>::max()) {
+ return std::min<int>(std::numeric_limits<int16_t>::max(), x);
+ } else if (x < std::numeric_limits<int16_t>::min()) {
+ return std::max<int>(std::numeric_limits<int16_t>::min(), x);
+ } else {
+ return x;
+ }
+}
+
+void im2col_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao) {
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int r = 0; r < conv_p.KH; ++r) {
+ int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
+ for (int s = 0; s < conv_p.KW; ++s) {
+ int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ // Ai: NHWC: NH_0W_0 x C_0
+ std::uint8_t val =
+ h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW
+ ? A_zero_point
+ : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
+ c];
+ // Ao: NHWC: NH_1W_1 x RSC_0
+ Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
+ conv_p.KW +
+ s) *
+ conv_p.IC +
+ c] = val;
+ } // for each c
+ } // for each s
+ } // for each r
+ } // for each w
+ } // for each h
+ } // for each n
+}
+
+void conv_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C) {
+ // filters are assumed to be in RSCK format
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int sum = 0;
+ for (int r = 0; r < conv_p.KH; ++r) {
+ int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
+ for (int s = 0; s < conv_p.KW; ++s) {
+ int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
+ w_in >= conv_p.IW
+ ? A_zero_point
+ : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) *
+ conv_p.IC +
+ c];
+ int b =
+ B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k];
+ sum += a * b;
+ } // for each c
+ } // for each s
+ } // for each r
+ C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum;
+ } // for each k
+ } // for each w
+ } // for each h
+ } // for each n
+}
+
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int8_t* B,
+ int32_t* C) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ int b = B[(k * R + r) * S + s];
+ sum += a * b;
+ }
+ }
+ C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+}
+
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
+ depthwise_3x3_pad_1_ref(
+ N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
+
+ vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+
+ for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point,
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+void depthwise_3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
+ depthwise_3x3_pad_1_ref(
+ N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
+
+ vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+
+ for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier[k],
+ C_zero_point,
+ A_zero_point,
+ B_zero_point[k],
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int8_t* B,
+ int32_t* C) {
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int k_t = 0; k_t < K_T; ++k_t) {
+ int t_in = -PAD_P + t * stride_t + k_t;
+ for (int k_h = 0; k_h < K_H; ++k_h) {
+ int h_in = -PAD_T + h * stride_h + k_h;
+ for (int k_w = 0; k_w < K_W; ++k_w) {
+ int w_in = -PAD_L + w * stride_w + k_w;
+ int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
+ w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
+ int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w];
+ sum += a * b;
+ }
+ }
+ }
+ C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B,
+ C_int32.data());
+
+ vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int k_t = 0; k_t < K_T; ++k_t) {
+ int t_in = -PAD_P + t * stride_t + k_t;
+ for (int k_h = 0; k_h < K_H; ++k_h) {
+ int h_in = -PAD_T + h * stride_h + k_h;
+ for (int k_w = 0; k_w < K_W; ++k_w) {
+ int w_in = -PAD_L + w * stride_w + k_w;
+ int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
+ w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ }
+ row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
+ sum;
+ }
+ } // w
+ } // h
+ } // t
+ } // for each n
+
+ for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point,
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+} // namespace fbgemm2
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
new file mode 100644
index 0000000..e9eaeed
--- /dev/null
+++ b/src/RefImplementations.h
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/FbgemmI8Spmdm.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Reference implementation of the requantization step using an int32
+ * multiplier and right shift.
+ * @param bias can be nullptr
+ */
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const std::int32_t* inp,
+ std::uint8_t* out,
+ std::int32_t C_multiplier,
+ std::int32_t C_right_shift,
+ std::int32_t C_zero_point,
+ std::int32_t A_zero_point,
+ std::int32_t B_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false);
+
+/**
+ * @brief Reference implementation of the requantization step using a float
+ * multiplier.
+ * @param bias can be nullptr
+ */
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const std::int32_t* inp,
+ std::uint8_t* out,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t A_zero_point,
+ std::int32_t B_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false);
+
+/**
+ * @brief Reference implementation of matrix multiply with uint8 for A,
+ * int8 for B, and 32-bit accumulation.
+ */
+void matmul_u8i8acc32_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const std::uint8_t* Aint8,
+ const std::int8_t* Bint8,
+ std::int32_t* Cint32);
+
+/**
+ * @brief Reference implementation of matrix multiply with uint8 for A,
+ * int8 for B, and 16-bit accumulation.
+ */
+void matmul_u8i8acc16_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ int brow,
+ const std::uint8_t* Aint8,
+ const std::int8_t* Bint8,
+ std::int32_t* Cint32);
+
+/**
+ * @brief Reference implementation of matrix multiply with fp32
+ * (single-precision) floating-point numbers.
+ */
+void matmul_fp_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const float* Afp32,
+ const float* Bfp32,
+ float* Cfp32);
+
+/**
+ * @brief Reference implementation to compute row_offsets (sums of rows of A).
+ */
+void row_offsets_u8acc32_ref(
+ int M,
+ int K,
+ int ld,
+ const std::uint8_t* Aint8,
+ std::int32_t* row_offsets);
+
+/**
+ * @brief Reference implementation to compute adjusted col_offsets (sum of
+ * columns of B and adjusted with B_zero_point)
+ */
+void col_offsets_with_zero_pt_s8acc32_ref(
+ int K,
+ int N,
+ int ld,
+ const std::int8_t* Bint8,
+ std::int32_t B_zero_point,
+ std::int32_t* col_offsets);
+
+/**
+ * @brief Reference implementation of SPMDM (sparse matrix times dense matrix).
+ */
+void spmdm_ref(
+ int M,
+ const std::uint8_t* A,
+ int lda,
+ CompressedSparseColumn& B,
+ bool accumulation,
+ std::int32_t* C,
+ int ldc);
+
+/*
+ * @brief Clip (saturate) a 32-bit integer to the 16-bit integer range.
+ */
+int32_t clip_16bit(int32_t x);
+
+/*
+ * @brief Reference implementation of convolution operation.
+ * The activations A are assumed to be in NHiWiC format.
+ * The filters B are assumed to be in RSCK format.
+ * The output C is assumed to be in NHoWoC format.
+ */
+void conv_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of im2col operation.
+ * The input A is assumed to be in NHiWiC format.
+ * The output Ao is assumed to be in NHoWoRSC format.
+ */
+void im2col_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao);
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1.
+ */
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1, followed by requantization (the same scaling factor and
+ * zero point for every channel).
+ */
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const std::int8_t* B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1, followed by requantization (a different scaling factor
+ * and zero point for each channel).
+ */
+void depthwise_3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const std::int8_t* B,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+/*
+ * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
+ * filter and padding size 1.
+ */
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
+ * filter and padding size 1, followed by requantization.
+ */
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const std::int8_t* B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+} // namespace fbgemm2
diff --git a/src/Utils.cc b/src/Utils.cc
new file mode 100644
index 0000000..10ab469
--- /dev/null
+++ b/src/Utils.cc
@@ -0,0 +1,357 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/Utils.h"
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <stdexcept>
+
+namespace fbgemm2 {
+
+/**
+ * @brief Compare the reference and test result matrices to check correctness.
+ * @param ref The buffer for the reference result matrix.
+ * @param test The buffer for the test result matrix.
+ * @param m The height of the reference and test result matrix.
+ * @param n The width of the reference and test result matrix.
+ * @param ld The leading dimension of the reference and test result matrix.
+ * @param max_mismatches_to_report The maximum number of tolerable mismatches to
+ * report.
+ * @param atol The tolerable error.
+ * @retval 1 If the number of mismatches between the reference and test result
+ * matrices exceeds max_mismatches_to_report.
+ * @retval 0 If the number of mismatches between the reference and test result
+ * matrices is within the tolerance.
+ */
+template <typename T>
+int compare_buffers(
+ const T* ref,
+ const T* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol /*=1e-3*/) {
+ size_t mismatches = 0;
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ T reference = ref[i * ld + j], actual = test[i * ld + j];
+ if (std::abs(reference - actual) > atol) {
+ std::cout << "\tmismatch at (" << i << ", " << j << ")" << std::endl;
+ if (std::is_integral<T>::value) {
+ std::cout << "\t reference:" << static_cast<int64_t>(reference)
+ << " test:" << static_cast<int64_t>(actual) << std::endl;
+ } else {
+ std::cout << "\t reference:" << reference << " test:" << actual
+ << std::endl;
+ }
+
+ mismatches++;
+ if (mismatches > max_mismatches_to_report) {
+ return 1;
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+
+/**
+ * @brief Print the matrix.
+ * @param op Transpose type of the matrix.
+ * @param R The height of the matrix.
+ * @param C The width of the matrix.
+ * @param ld The leading dimension of the matrix.
+ * @param name The prefix string before printing the matrix.
+ */
+template <typename T>
+void printMatrix(
+ matrix_op_t op,
+ const T* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name) {
+ // R: number of rows in op(inp)
+ // C: number of cols in op(inp)
+ // ld: leading dimension in inp
+ std::cout << name << ":"
+ << "[" << R << ", " << C << "]" << std::endl;
+ bool tr = (op == matrix_op_t::Transpose);
+ for (auto r = 0; r < R; ++r) {
+ for (auto c = 0; c < C; ++c) {
+ T res = tr ? inp[c * ld + r] : inp[r * ld + c];
+ if (std::is_integral<T>::value) {
+ std::cout << std::setw(5) << static_cast<int64_t>(res) << " ";
+ } else {
+ std::cout << std::setw(5) << res << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+}
+
+template int compare_buffers<float>(
+ const float* ref,
+ const float* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template int compare_buffers<int32_t>(
+ const int32_t* ref,
+ const int32_t* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template int compare_buffers<uint8_t>(
+ const uint8_t* ref,
+ const uint8_t* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template void printMatrix<float>(
+ matrix_op_t op,
+ const float* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<int8_t>(
+ matrix_op_t op,
+ const int8_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<uint8_t>(
+ matrix_op_t op,
+ const uint8_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<int32_t>(
+ matrix_op_t op,
+ const int32_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+
+
+/**
+ * @brief Reference implementation of matrix transposition: B = A^T.
+ * @param M The height of the matrix.
+ * @param N The width of the matrix.
+ * @param src The memory buffer of the source matrix A.
+ * @param ld_src The leading dimension of the source matrix A.
+ * @param dst The memory buffer of the destination matrix B.
+ * @param ld_dst The leading dimension of the destination matrix B.
+ */
+inline void transpose_ref(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ dst[i + j * ld_dst] = src[i * ld_src + j];
+ }
+ }
+}
+
+inline void
+transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
+ // load from src to registers
+ // a : a0 a1 a2 a3
+ // b : b0 b1 b2 b3
+ // c : c0 c1 c2 c3
+ // d : d0 d1 d2 d3
+ __m128 a = _mm_loadu_ps(&src[0 * ld_src]);
+ __m128 b = _mm_loadu_ps(&src[1 * ld_src]);
+ __m128 c = _mm_loadu_ps(&src[2 * ld_src]);
+ __m128 d = _mm_loadu_ps(&src[3 * ld_src]);
+
+ // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
+ // a : a0 b0 c0 d0
+ // b : a1 b1 c1 d1
+ // c : a2 b2 c2 d2
+ // d : a3 b3 c3 d3
+ _MM_TRANSPOSE4_PS(a, b, c, d);
+
+ // store from registers to dst
+ _mm_storeu_ps(&dst[0 * ld_dst], a);
+ _mm_storeu_ps(&dst[1 * ld_dst], b);
+ _mm_storeu_ps(&dst[2 * ld_dst], c);
+ _mm_storeu_ps(&dst[3 * ld_dst], d);
+}
+inline void transpose_4x4(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 4 <= M; ib += 4) {
+ for (jb = 0; jb + 4 <= N; jb += 4) {
+ transpose_kernel_4x4_sse(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
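+  // Handle the remainders with the scalar reference transpose: first the
+  // leftover columns of the fully blocked rows, then all remaining rows.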
+ transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+inline void transpose_kernel_8x8_avx2(
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // load from src to registers
+ // a : a0 a1 a2 a3 a4 a5 a6 a7
+ // b : b0 b1 b2 b3 b4 b5 b6 b7
+ // c : c0 c1 c2 c3 c4 c5 c6 c7
+ // d : d0 d1 d2 d3 d4 d5 d6 d7
+ // e : e0 e1 e2 e3 e4 e5 e6 e7
+ // f : f0 f1 f2 f3 f4 f5 f6 f7
+ // g : g0 g1 g2 g3 g4 g5 g6 g7
+ // h : h0 h1 h2 h3 h4 h5 h6 h7
+ __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
+ __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
+ __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
+ __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
+ __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
+ __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
+ __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
+ __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
+
+ __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
+ __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
+ // unpacking and interleaving 32-bit elements
+ // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
+ // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
+ // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
+ // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
+ // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
+ // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
+ // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
+ // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
+ ab0145 = _mm256_unpacklo_ps(a, b);
+ ab2367 = _mm256_unpackhi_ps(a, b);
+ cd0145 = _mm256_unpacklo_ps(c, d);
+ cd2367 = _mm256_unpackhi_ps(c, d);
+ ef0145 = _mm256_unpacklo_ps(e, f);
+ ef2367 = _mm256_unpackhi_ps(e, f);
+ gh0145 = _mm256_unpacklo_ps(g, h);
+ gh2367 = _mm256_unpackhi_ps(g, h);
+
+ // shuffling the 32-bit elements
+ // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
+ // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
+ // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
+  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
+ // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
+ // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
+ // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
+ // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
+ abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
+ abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
+ efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
+ efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
+ abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
+ abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
+ efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
+ efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);
+
+ // shuffling 128-bit elements
+ // a : a0 b0 c0 d0 e0 f0 g0 h0
+ // b : a1 b1 c1 d1 e1 f1 g1 h1
+ // c : a2 b2 c2 d2 e2 f2 g2 h2
+ // d : a3 b3 c3 d3 e3 f3 g3 h3
+ // e : a4 b4 c4 d4 e4 f4 g4 h4
+ // f : a5 b5 c5 d5 e5 f5 g5 h5
+ // g : a6 b6 c6 d6 e6 f6 g6 h6
+ // h : a7 b7 c7 d7 e7 f7 g7 h7
+ a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
+ b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
+ c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
+ d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
+ e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
+ f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
+ g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
+ h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);
+
+ // store from registers to dst
+ _mm256_storeu_ps(&dst[0 * ld_dst], a);
+ _mm256_storeu_ps(&dst[1 * ld_dst], b);
+ _mm256_storeu_ps(&dst[2 * ld_dst], c);
+ _mm256_storeu_ps(&dst[3 * ld_dst], d);
+ _mm256_storeu_ps(&dst[4 * ld_dst], e);
+ _mm256_storeu_ps(&dst[5 * ld_dst], f);
+ _mm256_storeu_ps(&dst[6 * ld_dst], g);
+ _mm256_storeu_ps(&dst[7 * ld_dst], h);
+}
+
+void transpose_8x8(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 8 <= M; ib += 8) {
+ for (jb = 0; jb + 8 <= N; jb += 8) {
+ transpose_kernel_8x8_avx2(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+void transpose_simd(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ transpose_16x16(M, N, src, ld_src, dst, ld_dst);
+ } else if (cpuinfo_has_x86_avx2()) {
+ transpose_8x8(M, N, src, ld_src, dst, ld_dst);
+ } else {
+ transpose_ref(M, N, src, ld_src, dst, ld_dst);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+} // namespace fbgemm2
diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc
new file mode 100644
index 0000000..b6bf413
--- /dev/null
+++ b/src/Utils_avx512.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm/Utils.h"
+
+#include <immintrin.h>
+
+namespace fbgemm2 {
+
+inline void transpose_kernel_16x16_avx512(
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // load from src to registers
+ // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
+ // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
+ // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
+ // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
+ // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
+ // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
+ // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
+ // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
+ // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
+ // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
+ // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
+ // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
+ // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
+ // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
+ // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
+ // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
+ __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
+ __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
+ __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
+ __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
+ __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
+ __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
+ __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
+ __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
+ __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
+ __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
+ __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
+ __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
+ __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
+ __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
+ __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
+ __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
+
+ __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
+ // unpacking and interleaving 32-bit elements
+ // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
+ // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
+ // c0 d0 c1 d1 ...
+ // c2 d2 c3 d3 ...
+ // e0 f0 e1 f1 ...
+ // e2 f2 e3 f3 ...
+ // g0 h0 g1 h1 ...
+ // g2 h2 g3 h3 ...
+ // i0 ...
+ // i2 ...
+ // k0 ...
+ // k2 ...
+ // m0 ...
+ // m2 ...
+ // o0 ...
+  // o2 ...
+ ta = _mm512_unpacklo_ps(a, b);
+ tb = _mm512_unpackhi_ps(a, b);
+ tc = _mm512_unpacklo_ps(c, d);
+ td = _mm512_unpackhi_ps(c, d);
+ te = _mm512_unpacklo_ps(e, f);
+ tf = _mm512_unpackhi_ps(e, f);
+ tg = _mm512_unpacklo_ps(g, h);
+ th = _mm512_unpackhi_ps(g, h);
+ ti = _mm512_unpacklo_ps(i, j);
+ tj = _mm512_unpackhi_ps(i, j);
+ tk = _mm512_unpacklo_ps(k, l);
+ tl = _mm512_unpackhi_ps(k, l);
+ tm = _mm512_unpacklo_ps(m, n);
+ tn = _mm512_unpackhi_ps(m, n);
+ to = _mm512_unpacklo_ps(o, p);
+ tq = _mm512_unpackhi_ps(o, p);
+
+ // unpacking and interleaving 64-bit elements
+ // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
+ // e1 f1 g1 h1 ...
+ // e2 f2 g2 h2 ...
+ // e3 f3 g3 h3 ...
+ // i0 j0 k0 l0 ...
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // m0 n0 o0 p0 ...
+ // m1 n1 o1 p1 ...
+ // m2 n2 o2 p2 ...
+ // m3 n3 o3 p3 ...
+ a = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ b = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ c = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ d = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ e = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ f = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ g = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ h = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ i = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ j = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ k = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ l = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ m = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ n = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ o = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+ p = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // a4 b4 c4 d4 ...
+ // a5 b5 c5 d5 ...
+ // a6 b6 c6 d6 ...
+ // a7 b7 c7 d7 ...
+ // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // i4 j4 k4 l4 ...
+ // i5 j5 k5 l5 ...
+ // i6 j6 k6 l6 ...
+ // i7 j7 k7 l7 ...
+ ta = _mm512_shuffle_f32x4(a, e, 0x88);
+ tb = _mm512_shuffle_f32x4(b, f, 0x88);
+ tc = _mm512_shuffle_f32x4(c, g, 0x88);
+ td = _mm512_shuffle_f32x4(d, h, 0x88);
+ te = _mm512_shuffle_f32x4(a, e, 0xdd);
+ tf = _mm512_shuffle_f32x4(b, f, 0xdd);
+ tg = _mm512_shuffle_f32x4(c, g, 0xdd);
+ th = _mm512_shuffle_f32x4(d, h, 0xdd);
+ ti = _mm512_shuffle_f32x4(i, m, 0x88);
+ tj = _mm512_shuffle_f32x4(j, n, 0x88);
+ tk = _mm512_shuffle_f32x4(k, o, 0x88);
+ tl = _mm512_shuffle_f32x4(l, p, 0x88);
+ tm = _mm512_shuffle_f32x4(i, m, 0xdd);
+ tn = _mm512_shuffle_f32x4(j, n, 0xdd);
+ to = _mm512_shuffle_f32x4(k, o, 0xdd);
+ tq = _mm512_shuffle_f32x4(l, p, 0xdd);
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 ... o0
+ // a1 b1 c1 d1 ... o1
+ // a2 b2 c2 d2 ... o2
+ // a3 b3 c3 d3 ... o3
+ // a4 ...
+ // a5 ...
+ // a6 ...
+ // a7 ...
+ // a8 ...
+ // a9 ...
+ // a10 ...
+ // a11 ...
+ // a12 ...
+ // a13 ...
+ // a14 ...
+ // a15 b15 c15 d15 ... o15
+ a = _mm512_shuffle_f32x4(ta, ti, 0x88);
+ b = _mm512_shuffle_f32x4(tb, tj, 0x88);
+ c = _mm512_shuffle_f32x4(tc, tk, 0x88);
+ d = _mm512_shuffle_f32x4(td, tl, 0x88);
+ e = _mm512_shuffle_f32x4(te, tm, 0x88);
+ f = _mm512_shuffle_f32x4(tf, tn, 0x88);
+ g = _mm512_shuffle_f32x4(tg, to, 0x88);
+ h = _mm512_shuffle_f32x4(th, tq, 0x88);
+ i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
+ j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
+ k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
+ l = _mm512_shuffle_f32x4(td, tl, 0xdd);
+ m = _mm512_shuffle_f32x4(te, tm, 0xdd);
+ n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
+ o = _mm512_shuffle_f32x4(tg, to, 0xdd);
+ p = _mm512_shuffle_f32x4(th, tq, 0xdd);
+
+ // store from registers to dst
+ _mm512_storeu_ps(&dst[0 * ld_dst], a);
+ _mm512_storeu_ps(&dst[1 * ld_dst], b);
+ _mm512_storeu_ps(&dst[2 * ld_dst], c);
+ _mm512_storeu_ps(&dst[3 * ld_dst], d);
+ _mm512_storeu_ps(&dst[4 * ld_dst], e);
+ _mm512_storeu_ps(&dst[5 * ld_dst], f);
+ _mm512_storeu_ps(&dst[6 * ld_dst], g);
+ _mm512_storeu_ps(&dst[7 * ld_dst], h);
+ _mm512_storeu_ps(&dst[8 * ld_dst], i);
+ _mm512_storeu_ps(&dst[9 * ld_dst], j);
+ _mm512_storeu_ps(&dst[10 * ld_dst], k);
+ _mm512_storeu_ps(&dst[11 * ld_dst], l);
+ _mm512_storeu_ps(&dst[12 * ld_dst], m);
+ _mm512_storeu_ps(&dst[13 * ld_dst], n);
+ _mm512_storeu_ps(&dst[14 * ld_dst], o);
+ _mm512_storeu_ps(&dst[15 * ld_dst], p);
+}
+
+void transpose_16x16(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 16 <= M; ib += 16) {
+ for (jb = 0; jb + 16 <= N; jb += 16) {
+ transpose_kernel_16x16_avx512(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+} // namespace fbgemm2
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
new file mode 100644
index 0000000..8e36c85
--- /dev/null
+++ b/src/codegen_fp16fp32.cc
@@ -0,0 +1,387 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <assert.h>
+#include <cpuid.h>
+#include <stdlib.h>
+#include <string.h>
+#include <algorithm>
+#include <array>
+#include <fstream>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <string>
+
+using namespace std;
+
+void addi(ofstream& of, string i, bool disable = false) {
+ if (disable == false)
+ of << "\"" + i + "\\t\\n\"" + "\n";
+}
+
+struct ISA {
+ unsigned avx; // 1, 2 or 3
+ string name;
+ vector<vector<unsigned>> shapes;
+};
+
+int main() {
+ bool iaca = false;
+ bool disable = false;
+
+ bool fixedA = true, fixedB = true, fixedC = true;
+
+ int eax, ebx, ecx, edx;
+  __cpuid(1 /* leaf 1: processor info and feature bits */, eax, ebx, ecx, edx);
+  printf("F16C is %ssupported\n", ((ecx & bit_F16C) ? "" : "not "));
+
+ string comma = ",";
+
+ vector<ISA> isa = {
+ // {1, "AVX", {{4, 1, 0}, {4, 2, 0}, {4, 3, 0}, {3, 1, 0}, {3, 2, 0}, {3,
+ // 3, 0}}},
+ { 2, "AVX2",
+ { { 1, 1, 0 },
+ { 2, 1, 0 },
+ { 3, 1, 0 },
+ { 4, 1, 0 },
+ { 5, 1, 0 },
+ { 6, 1, 0 },
+ { 7, 1, 0 },
+ { 8, 1, 0 },
+ { 9, 1, 0 },
+ { 10, 1, 0 },
+ { 11, 1, 0 },
+ { 12, 1, 0 },
+ { 13, 1, 0 },
+ { 14, 1, 0 },
+ }
+ }
+ };
+
+ // open all files
+ ofstream srcfile;
+ srcfile.open("FbgemmFP16UKernels.cc");
+ srcfile << "#include \"FbgemmFP16UKernels.h\"\n";
+ if (iaca)
+ srcfile << "#include \"iacaMarks.h\"\n";
+
+ ofstream hdrfile;
+ hdrfile.open("FbgemmFP16UKernels.h");
+
+ hdrfile << "#ifndef FBGEMM_UKERNELS\n";
+ hdrfile << "#define FBGEMM_UKERNELS\n";
+ hdrfile << "#include <cstdint>\n";
+ hdrfile << "#include <tuple>\n";
+ hdrfile << "#include <vector>\n";
+ hdrfile << "#include \"fbgemm/Types.h\"\n";
+ hdrfile << "using fp16 = fbgemm2::float16;\n";
+ hdrfile << "using fp32 = float;\n";
+ hdrfile << "struct GemmParams {uint64_t k; float *A; const fp16 *B;\n"
+ "float *beta; uint64_t accum; float *C; uint64_t ldc;\n"
+ "uint64_t b_block_cols; uint64_t b_block_size;};\n";
+
+ std::map<string, string> fptr_typedef;
+ fptr_typedef["fp16"] = "";
+ fptr_typedef["fp32"] = "";
+
+ unsigned labelId = 0;
+#if 1
+ for (auto fixedA : {false})
+ for (auto fixedB : {false})
+ for (auto fixedC : {false})
+#else
+ for (auto fixedA : {true})
+ for (auto fixedB : {true})
+ for (auto fixedC : {true})
+#endif
+ for (auto s : isa) {
+ vector<vector<unsigned>>& ukernel_shape = s.shapes;
+
+ vector<string> funcname(ukernel_shape.size()),
+ fheader(ukernel_shape.size());
+ string fargs;
+
+ for (auto fp16 : {true}) {
+ string B_type = ((fp16) ? "fp16" : "fp32");
+ string prefix = s.name + /*"_" + B_type */ + "_" + "fA" +
+ to_string(fixedA) + "fB" + to_string(fixedB) + "fC" +
+ to_string(fixedC);
+ cout << "Generating code for " << s.name << " " << B_type << "\n";
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ printf(
+ "shape: %d x %d * 32\n",
+ ukernel_shape[k][0],
+ ukernel_shape[k][1]);
+
+ string p1 = "GemmParams *gp";
+
+ funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) +
+ "x" + to_string(ukernel_shape[k][1]) + "_";
+ funcname[k] += prefix;
+
+ fargs = "(" + p1 + ")";
+
+ fheader[k] =
+ "void __attribute__ ((noinline)) " + funcname[k] + fargs;
+ srcfile << fheader[k] << "\n";
+ srcfile << "{\n";
+
+ unsigned last_free_ymmreg = 0;
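+          // ymm registers are handed out in order: the mr x nr tile of C
+          // accumulators first, then one register for the broadcast A value
+          // and one per B column; the asserts below keep the total within the
+          // 16 available ymm registers (ymm14/ymm15 are also used directly to
+          // double-buffer the fp16->fp32 converted B column when mr <= 13).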
+ // produce register block of C
+ vector<vector<string>> vCtile(ukernel_shape[k][0]);
+ for (auto r = 0; r < ukernel_shape[k][0]; r++)
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vCtile[r].push_back("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+ assert(last_free_ymmreg <= 14);
+
+ string vAtmp = "ymm" + to_string(last_free_ymmreg++);
+ // produce register block of B col
+ assert(ukernel_shape[k][1] == 1);
+ vector<string> vBcol(ukernel_shape[k][1]);
+
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vBcol[c] = ("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+
+ assert(last_free_ymmreg <= 16);
+
+ srcfile << "asm volatile\n";
+ srcfile << "(\n";
+
+ srcfile << "#if !defined(__clang__)" << "\n";
+ addi(srcfile, "mov r14, %[gp]");
+ srcfile << "#else\n";
+ addi(srcfile, "mov %[gp], %%r14");
+ addi(srcfile, ".intel_syntax noprefix");
+ srcfile << "#endif\n";
+
+ srcfile << "\n// Copy parameters\n";
+ srcfile << "// k\n";
+ addi(srcfile, "mov r8, [r14 + 0]");
+ srcfile << "// A\n";
+ addi(srcfile, "mov r9, [r14 + 8]");
+ srcfile << "// B\n";
+ addi(srcfile, "mov r10, [r14 + 16]");
+ srcfile << "// beta\n";
+ addi(srcfile, "mov r15, [r14 + 24]");
+ srcfile << "// accum\n";
+ addi(srcfile, "mov rdx, [r14 + 32]");
+ srcfile << "// C\n";
+ addi(srcfile, "mov r12, [r14 + 40]");
+ srcfile << "// ldc\n";
+ addi(srcfile, "mov r13, [r14 + 48]");
+ srcfile << "// b_block_cols\n";
+ addi(srcfile, "mov rdi, [r14 + 56]");
+ srcfile << "// b_block_size\n";
+ addi(srcfile, "mov rsi, [r14 + 64]");
+ srcfile << "// Make copies of A and C\n";
+ addi(srcfile, "mov rax, r9");
+ addi(srcfile, "mov rcx, r12");
+ srcfile << "\n\n";
+
+ addi(srcfile, "mov rbx, 0");
+
+ string exitlabel = "L_exit%=";
+ string label2 = "loop_outter%=";
+ addi(srcfile, label2 + ":");
+ addi(srcfile, "mov r14, 0");
+
+ // set all vCtile regs to zeros
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
+ vCtile[r][c]);
+ }
+ }
+
+ // start marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 111");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ srcfile << "\n";
+
+ if (ukernel_shape[k][0] <= 13) {
+ addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]");
+ addi(srcfile, "mov r11, 16");
+ } else {
+ addi(srcfile, "mov r11, 0");
+ }
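+ // For tiles of at most 13 rows both ymm14 and ymm15 are free, so one
+ // fp16 B column is pre-converted here and r11 (the byte offset into B)
+ // starts at 16; the inner loop below double-buffers the two registers.
+ // For the 14-row tile every ymm register is taken by C/A/B, so r11
+ // starts at 0 and the simpler non-unrolled path further down is used.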
+
+ srcfile << "\n";
+ string label = "loop_inner%=";
+ addi(srcfile, label + ":");
+ srcfile << "\n";
+
+ if (ukernel_shape[k][0] <= 13) {
+ auto a_offset = 0, unroll_factor = 2;
+ for (auto u = 0; u < unroll_factor; u++) {
+ string breg = (u == 0) ? "ymm14" : "ymm15";
+ string breg_rev = (u == 0) ? "ymm15" : "ymm14";
+
+ addi(srcfile, "vcvtph2ps " + breg +
+ ",XMMWORD PTR [r10 + r11 + " +
+ to_string(u * 16) + "]");
+ addi(srcfile, "inc r14");
+ for (auto r = 0; r < vCtile.size(); r++) {
+ addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(a_offset) + "]");
+ addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," +
+ breg_rev + "," + vAtmp);
+ if (u == 1 && r == vCtile.size() / 2)
+ addi(srcfile, "add r11, 32");
+ a_offset += 4;
+ }
+ if (u < unroll_factor - 1) {
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jge " + exitlabel);
+ }
+ }
+
+ addi(srcfile, "add r9," + to_string(a_offset));
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+
+ srcfile << "\n";
+
+ addi(srcfile, exitlabel + ":");
+ } else {
+ addi(srcfile,
+ "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]");
+ for (auto r = 0; r < vCtile.size(); r++) {
+ addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(4 * r) + "]");
+ addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] +
+ "," + vAtmp);
+ }
+
+ addi(srcfile, "add r9," + to_string(4 * ukernel_shape[k][0]),
+ fixedA); // move A ptr
+ addi(srcfile, "add r11, 16");
+
+ addi(srcfile, "inc r14");
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+ }
+
+ addi(srcfile, "add r10, rsi");
+ srcfile << "\n";
+
+ // end marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 222");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+
+ addi(srcfile, "cmp rdx, 1");
+ addi(srcfile, "je L_accum%=");
+ srcfile << "// Dump C\n";
+
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(srcfile, "vmovups YMMWORD PTR [r12 + " +
+ to_string(32 * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+ addi(srcfile, "jmp L_done%=");
+
+ srcfile << "\n\n";
+ addi(srcfile, "L_accum%=:");
+ srcfile << "// Dump C with accumulate\n";
+
+ string r_spare = (s.avx == 1) ? "ymm14" : "ymm15";
+ addi(srcfile,
+ "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
+ fixedC);
+ // store out C
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ switch (s.avx) {
+ case 1:
+ addi(srcfile,
+ string("vmulps ymm15, ") + r_spare + comma +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ addi(srcfile, "vaddps " + vCtile[r][c] + "," +
+ vCtile[r][c] + "," + "ymm15",
+ fixedC);
+ break;
+ case 2:
+ addi(srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ break;
+ default:
+ assert(0);
+ }
+ addi(srcfile, "vmovups YMMWORD PTR [r12 + " +
+ to_string(32 * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
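+ // The accumulate path reloads the stored C tile, scales it by beta
+ // (broadcast from gp->beta into r_spare) and adds it to the fresh
+ // result: plain AVX has no FMA, hence vmulps + vaddps, while AVX2
+ // fuses both into a single vfmadd231ps.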
+
+ srcfile << "\n";
+ addi(srcfile, "L_done%=:");
+
+ srcfile << "\n// next outer iteration\n";
+ // C
+ addi(srcfile, "add rcx, " + to_string(32 * ukernel_shape[k][1]),
+ fixedC);
+ addi(srcfile, "mov r12, rcx", fixedC);
+ // A
+ addi(srcfile, "mov r9, rax");
+
+ addi(srcfile, "inc rbx");
+ addi(srcfile, "cmp rbx, rdi");
+ addi(srcfile, "jl " + label2);
+
+ // output
+ srcfile << ":\n";
+ // input
+ srcfile << ":\n";
+ srcfile << "[gp] \"rm\" (gp)\n";
+
+ // clobbered
+ srcfile
+ << (string) ": \"r8\", \"r9\", \"r10\", \"r11\", \"r15\", " +
+ (string) " \"r13\", \"r14\",\n" +
+ (string) "\"rax\", \"rcx\", "
+ "\"rdx\", \"rsi\", \"rdi\", \"rbx\", "
+ "\"r12\", \"memory\"" +
+ (string) "\n";
+ srcfile << ");\n";
+ srcfile << "}\n";
+ }
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ hdrfile << fheader[k] << ";\n";
+ }
+
+ fptr_typedef[B_type] =
+ "typedef void (* funcptr_" + B_type + ") " + fargs;
+ }
+ }
+
+ srcfile.close();
+ hdrfile << fptr_typedef["fp16"] << ";\n";
+ hdrfile << fptr_typedef["fp32"] << ";\n";
+ hdrfile << "#endif\n";
+ hdrfile.close();
+}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
new file mode 100644
index 0000000..9005dab
--- /dev/null
+++ b/test/CMakeLists.txt
@@ -0,0 +1,47 @@
+cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+
+if(FBGEMM_BUILD_TESTS AND NOT TARGET gtest)
+ #Download Googletest framework from github if
+ #GOOGLETEST_SOURCE_DIR is not specified.
+ if(NOT DEFINED GOOGLETEST_SOURCE_DIR)
+ message(STATUS "Downloading googletest to
+ ${FBGEMM_THIRDPARTY_DIR}/googletest
+ (define GOOGLETEST_SOURCE_DIR to avoid it)")
+ configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadGTEST.cmake"
+ "${FBGEMM_BINARY_DIR}/googletest-download/CMakeLists.txt")
+ execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/googletest-download")
+ execute_process(COMMAND "${CMAKE_COMMAND}" --build .
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/googletest-download")
+ set(GOOGLETEST_SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/googletest" CACHE STRING
+ "googletest source directory")
+ endif()
+
+ #build Googletest framework
+ add_subdirectory("${GOOGLETEST_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/googletest")
+endif()
+
+macro(add_gtest TESTNAME)
+ add_executable(${TESTNAME} ${ARGN}
+ ../bench/BenchUtils.cc QuantizationHelpers.cc TestUtils.cc)
+ set_target_properties(${TESTNAME} PROPERTIES
+ CXX_STANDARD 11
+ CXX_EXTENSIONS NO)
+ target_compile_options(${TESTNAME} PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel")
+ target_link_libraries(${TESTNAME} gtest gmock gtest_main fbgemm)
+ add_dependencies(${TESTNAME} gtest fbgemm)
+ add_test(${TESTNAME} ${TESTNAME})
+ set_target_properties(${TESTNAME} PROPERTIES FOLDER test)
+endmacro()
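+# Each *Test.cc picked up by the foreach loop below gets its own gtest
+# executable and CTest entry; FP16Test.cc, for example, becomes a target
+# and test named FP16Test.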
+
+
+file(GLOB TEST_LIST "*Test.cc")
+
+foreach(TEST_FILE ${TEST_LIST})
+ get_filename_component(TEST_NAME "${TEST_FILE}" NAME_WE)
+ get_filename_component(TEST_FILE_ONLY "${TEST_FILE}" NAME)
+ add_gtest("${TEST_NAME}"
+ "${TEST_FILE_ONLY}")
+endforeach()
+
diff --git a/test/FP16Test.cc b/test/FP16Test.cc
new file mode 100644
index 0000000..c346049
--- /dev/null
+++ b/test/FP16Test.cc
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <random>
+
+#include <gtest/gtest.h>
+
+#include "fbgemm/FbgemmFP16.h"
+#include "src/RefImplementations.h"
+#include "bench/BenchUtils.h"
+#include "TestUtils.h"
+
+#ifdef USE_IACA
+#include "iacaMarks.h"
+#endif
+
+using namespace std;
+using namespace fbgemm2;
+
+namespace {
+ // The test parameter is a pair of transpose options for matrices A and B.
+ class FBGemmFP16Test :
+ public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {};
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ FBGemmFP16Test,
+ ::testing::Values(
+ pair<matrix_op_t, matrix_op_t>(
+ matrix_op_t::NoTranspose, matrix_op_t::NoTranspose),
+ pair<matrix_op_t, matrix_op_t>(
+ matrix_op_t::NoTranspose, matrix_op_t::Transpose)/*,
+ pair<matrix_op_t, matrix_op_t>(
+ matrix_op_t::Transpose, matrix_op_t::NoTranspose),
+ pair<matrix_op_t, matrix_op_t>(
+ matrix_op_t::Transpose, matrix_op_t::Transpose)*/));
+
+TEST_P(FBGemmFP16Test, Test) {
+ vector<vector<int> > shapes;
+ random_device r;
+ default_random_engine generator(r());
+ uniform_int_distribution<int> dm(1,100);
+ uniform_int_distribution<int> dnk(1,1024);
+ for (int i = 0; i < 10; i++) {
+ int m = dm(generator);
+ int n = dnk(generator);
+ int k = dnk(generator);
+ shapes.push_back({m, n, k});
+ if (m > 10) {
+ shapes.push_back({(m / 10) * 10, n, k});
+ }
+ }
+
+ float alpha = 1.f, beta = 0.f;
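+ // alpha is handed to PackedGemmMatrixFP16 below (so it is presumably
+ // folded into the packed B) and beta scales the existing C; with
+ // alpha = 1 and beta = 0 the FBGEMM result should match the plain
+ // reference sgemm exactly.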
+ matrix_op_t atrans, btrans;
+ tie(atrans, btrans) = GetParam();
+
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+ int k = s[2];
+
+ cerr << "m = " << m << " n = " << n << " k = " << k;
+ if (atrans == matrix_op_t::Transpose) {
+ cerr << " A_transposed";
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ cerr << " B_transposed";
+ }
+ cerr << endl;
+
+ aligned_vector<float> A(m * k, 0.f);
+ aligned_vector<float> B(k * n, 0.f);
+ aligned_vector<float> C(m * n, 0.f);
+
+ // initialize with small numbers
+ randFill(A, 0, 4);
+ randFill(B, 0, 4);
+ randFill(C, 0, 4);
+
+ aligned_vector<float> A_ref, B_ref, C_ref;
+ A_ref = A;
+ B_ref = B;
+ C_ref = C;
+
+ if (atrans == matrix_op_t::Transpose) {
+ transpose_matrix(A_ref.data(), k, m);
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(B_ref.data(), n, k);
+ }
+
+ // Gold via reference sgemm
+ matmul_fp_ref(
+ m,
+ n,
+ k,
+ k,
+ n,
+ n,
+ A_ref.data(),
+ B_ref.data(),
+ C_ref.data());
+
+ // fbgemm fp16
+ PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
+ cblas_gemm_compute(atrans, m, A.data(), Bp, beta, C.data());
+
+ // correctness check
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ float expected = C_ref[i * n + j];
+ float actual = C[i * n + j];
+ EXPECT_EQ(expected, actual) <<
+ "GEMM results differ at (" << i << ", " << j <<
+ "). ref " << expected << " FBGemm " << actual;
+ }
+ }
+ }
+}
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
new file mode 100644
index 0000000..cfde880
--- /dev/null
+++ b/test/I8DepthwiseTest.cc
@@ -0,0 +1,448 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "I8DepthwiseTest.h"
+
+#include <cmath>
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+#include "bench/AlignedVec.h"
+#include "src/FbgemmI8Depthwise.h"
+#include "src/RefImplementations.h"
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+
+using namespace std;
+
+namespace fbgemm2
+{
+
+// From Xray OCR
+static vector<vector<int>> shapes = {
+ // N, K, H_in, W_in, stride
+ { 1, 272, 47, 125, 1, },
+// { 1, 272, 64, 125, 1, },
+// { 1, 272, 66, 125, 1, },
+// { 1, 272, 67, 100, 1, },
+// { 1, 272, 75, 75, 1, },
+ { 1, 272, 75, 76, 1, },
+// { 1, 272, 75, 100, 1, },
+// { 1, 272, 94, 75, 1, },
+// { 1, 272, 109, 75, 1, },
+ { 1, 544, 24, 63, 1, },
+// { 1, 544, 33, 63, 1, },
+// { 1, 544, 34, 50, 1, },
+// { 1, 544, 36, 63, 1, },
+// { 1, 544, 38, 38, 1, },
+// { 1, 544, 38, 40, 1, },
+ { 1, 544, 47, 38, 1, },
+ { 1, 1088, 7, 7, 1, },
+ { 51, 1088, 7, 7, 1, },
+// { 100, 1088, 7, 7, 1, },
+
+ { 1, 248, 93, 250, 2, },
+// { 1, 248, 128, 250, 2, },
+// { 1, 248, 133, 200, 2, },
+// { 1, 248, 150, 150, 2, },
+ { 1, 248, 150, 151, 2, },
+// { 1, 248, 150, 158, 2, },
+// { 1, 248, 188, 150, 2, },
+// { 1, 248, 225, 150, 2, },
+ { 1, 272, 47, 125, 2, },
+// { 1, 272, 64, 125, 2, },
+// { 1, 272, 66, 125, 2, },
+// { 1, 272, 67, 100, 2, },
+// { 1, 272, 75, 75, 2, },
+// { 1, 272, 75, 76, 2, },
+ { 1, 272, 94, 75, 2, },
+ { 1, 544, 14, 14, 2, },
+ { 51, 544, 14, 14, 2, },
+// { 100, 544, 14, 14, 2, },
+
+ { 1, 8, 4, 4, 1, },
+};
+
+TEST(FBGemmDepthWiseTest, Test3x3) {
+ for (auto shape : shapes) {
+ int N = shape[0];
+ int K = shape[1];
+ int H = shape[2];
+ int W = shape[3];
+ int stride_h = shape[4];
+ int stride_w = stride_h;
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
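+ // With 3x3 filters, padding 1 on every side and stride 1 this is the
+ // usual "same" output size, e.g. H = 47 gives (47 + 1 + 1 - 3) / 1 + 1 = 47;
+ // with stride 2 the spatial dims are roughly halved, e.g. H = 93 gives
+ // (93 + 2 - 3) / 2 + 1 = 47.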
+
+ aligned_vector<uint8_t> A(N * H * W * K);
+ aligned_vector<int8_t> B(K * R * S);
+ aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * K), C(C_ref.size());
+
+ randFill(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ randFill(B, -16, 16);
+ int32_t B_zero_point = 5;
+
+ depthwise_3x3_pad_1_ref(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
+ int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+
+ float C_multiplier = 255. / (maximum - minimum);
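+ // The multiplier maps the observed int32 range onto the 8-bit output
+ // range, e.g. min = -1000 and max = 3000 give 255 / 4000 ~= 0.064.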
+
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+ int32_t C_zero_point = 5;
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3_pad_1_ref(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ B.data(),
+ C_multiplier,
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3ConvMatrix Bp(K, B.data());
+
+ depthwise_3x3_pad_1(
+ N, H, W, K, stride_h, stride_w, A_zero_point, A.data(), Bp, C.data());
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected = C_ref[((n * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual = C[((n * H_OUT + h) * W_OUT + w) * K + k];
+ EXPECT_EQ(expected, actual) <<
+ "Depthwise 3x3 results differ at (" << n << ", " <<
+ h << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+
+ depthwise_3x3_pad_1(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_uint8.data(),
+ col_offsets.data(),
+ bias.data(),
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected =
+ C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + k];
+ EXPECT_EQ(expected, actual) <<
+ "Depthwise 3x3 results differ at (" << n << ", " <<
+ h << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+ } // for each shape
+} // Test3x3
+
+TEST(FBGemmDepthWiseTest, Test3x3x3) {
+ for (auto shape : shapes_3d) {
+ int N = shape[0];
+ int K = shape[1];
+ int T = shape[2];
+ int H = shape[3];
+ int W = shape[4];
+ int stride_t = shape[5];
+ int stride_h = stride_t;
+ int stride_w = stride_t;
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ aligned_vector<uint8_t> A(N * T * H * W * K);
+ aligned_vector<int8_t> B(K * K_T * K_H * K_W);
+ aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
+ C(C_ref.size());
+
+ randFill(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ randFill(B, -16, 16);
+ int32_t B_zero_point = 5;
+
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
+ int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+
+ float C_multiplier = 255. / (maximum - minimum);
+
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+ int32_t C_zero_point = 5;
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point,
+ B.data(),
+ C_multiplier,
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3x3ConvMatrix Bp(K, B.data());
+
+ depthwise_3x3x3_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ Bp,
+ C.data());
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected =
+ C_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual =
+ C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ ASSERT_EQ(expected, actual)
+ << "Depthwise 3x3 results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << k << ") "
+ << shape[0] << " " << shape[1] << " " << shape[2] << " "
+ << shape[3] << " " << shape[4] << " " << shape[5];
+ }
+ } // w
+ } // h
+ } // t
+ } // n
+
+ depthwise_3x3x3_pad_1(
+ N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A.data(),
+ B_zero_point, Bp, C_multiplier, C_zero_point,
+ C_uint8.data(), col_offsets.data(), bias.data(),
+ false /* fuse_relu */, 0, 1);
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected = C_uint8_ref
+ [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual =
+ C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ EXPECT_EQ(expected, actual)
+ << "Depthwise 3x3 results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << k << ").";
+ }
+ } // w
+ } // h
+ } // t
+ } // n
+ } // for each shape
+} // Test3x3x3
+
+TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
+ for (auto shape : shapes) {
+ int N = shape[0];
+ int K = shape[1];
+ int H = shape[2];
+ int W = shape[3];
+ int stride_h = shape[4];
+ int stride_w = stride_h;
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ aligned_vector<uint8_t> A(N * H * W * K);
+ aligned_vector<int8_t> B(K * R * S);
+ int32_t C_num_rows = N * H_OUT * W_OUT;
+ aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+
+ randFill(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ // Each row of B (one 3x3 filter per channel) gets a different range to
+ // really test per-channel quantization.
+ vector<int32_t> B_zero_point(K);
+ for (auto k = 0; k < K; ++k) {
+ aligned_vector<int8_t> Bk(R * S);
+ randFill(Bk, -16 + k, 16 + k);
+ copy(Bk.begin(), Bk.end(), B.begin() + k * R * S);
+
+ B_zero_point[k] = 5 + k;
+ }
+
+ depthwise_3x3_pad_1_ref(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ aligned_vector<int32_t> C_ref_transpose(C_ref);
+ transpose_matrix(C_ref_transpose.data(), C_num_rows, K);
+ vector<float> C_multiplier(K);
+ for (auto k = 0; k < K; ++k) {
+ auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
+ auto C_ref_k_end = C_ref_k_begin + C_num_rows;
+ int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
+ int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
+ C_multiplier[k] = 255. / (maximum - minimum);
+ cerr << "k " << k << " minimum " << minimum << " maximum " << maximum
+ << " multiplier " << C_multiplier[k] << endl;
+ }
+ int32_t C_zero_point = 5;
+
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3_per_channel_quantization_pad_1_ref(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point.data(),
+ B.data(),
+ C_multiplier.data(),
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3ConvMatrix Bp(K, B.data());
+
+ depthwise_3x3_per_channel_quantization_pad_1(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point.data(),
+ Bp,
+ C_multiplier.data(),
+ C_zero_point,
+ C_uint8.data(),
+ col_offsets.data(),
+ bias.data(),
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected =
+ C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + k];
+ EXPECT_EQ(expected, actual) <<
+ "Depthwise 3x3 results differ at (" << n << ", " <<
+ h << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+ } // for each shape
+} // Test3x3PerChannelQuantization
+
+} // namespace fbgemm2
diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h
new file mode 100644
index 0000000..38dea6f
--- /dev/null
+++ b/test/I8DepthwiseTest.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <vector>
+
+namespace fbgemm2
+{
+
+// From ResNeXt-3D-101
+static std::vector<std::vector<int>> shapes_3d = {
+ // N, K, T_in, H_in, W_in, stride
+ { 1, 64, 32, 56, 56, 1, },
+ { 1, 128, 16, 28, 28, 1, },
+ { 1, 256, 8, 14, 14, 1, },
+ { 1, 512, 4, 7, 7, 1, },
+
+ { 1, 128, 32, 56, 56, 2, },
+ { 1, 256, 16, 28, 28, 2, },
+ { 1, 512, 8, 14, 14, 2, },
+
+ { 5, 64, 32, 56, 56, 1, },
+ { 5, 128, 16, 28, 28, 1, },
+ { 5, 256, 8, 14, 14, 1, },
+ { 5, 512, 4, 7, 7, 1, },
+
+ { 5, 128, 32, 56, 56, 2, },
+ { 5, 256, 16, 28, 28, 2, },
+ { 5, 512, 8, 14, 14, 2, },
+
+ { 1, 8, 4, 4, 4, 1, },
+};
+
+} // namespace fbgemm2
diff --git a/test/I8SpmdmTest.cc b/test/I8SpmdmTest.cc
new file mode 100644
index 0000000..c74c98a
--- /dev/null
+++ b/test/I8SpmdmTest.cc
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <random>
+
+#include <gtest/gtest.h>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "fbgemm/FbgemmI8Spmdm.h"
+#include "src/RefImplementations.h"
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+std::vector<float> densities{0.0001f, 0.001f, 0.01f, 0.1f, 1.0f};
+
+namespace {
+class fbgemmSPMDMTest
+ : public testing::TestWithParam<std::tuple<float, bool, bool>> {};
+} // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ Instance0,
+ fbgemmSPMDMTest,
+ ::testing::Combine(
+ ::testing::ValuesIn(densities),
+ ::testing::Bool(),
+ ::testing::Bool()));
+
+TEST_P(fbgemmSPMDMTest, TestsSpMDM) {
+ const vector<array<int, 3>> shapes = {
+ // M, N, K
+ { 1024, 1024, 1024 },
+ { 511, 512, 512 },
+ { 111, 111, 111 },
+ { 14*14*2, 4, 2 },
+ };
+
+ float density;
+ bool accumulation, test_ld;
+ tie(density, accumulation, test_ld) = GetParam();
+
+ for (const auto& shape : shapes) {
+ int M = shape[0];
+ int N = shape[1];
+ int K = shape[2];
+ int N_adjusted = N;
+ int K_adjusted = K;
+ if (test_ld) {
+ // When test_ld is true, we multiply with the bottom-right quadrant of B
+ N_adjusted = std::max(N / 2, 1);
+ K_adjusted = std::max(K / 2, 1);
+ }
+
+ aligned_vector<uint8_t> A(M * K);
+ randFill(A, 0, 255);
+
+ CompressedSparseColumn B_csc(K_adjusted, N_adjusted);
+ vector<int32_t> C(M * N);
+ vector<int32_t> C_ref(C.size());
+
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ C_ref[i * N + j] = i + j;
+ }
+ }
+
+ // deterministic random number
+ default_random_engine eng;
+ binomial_distribution<> per_col_nnz_dist(K_adjusted, density);
+ uniform_int_distribution<> value_dist(
+ numeric_limits<int8_t>::min() / 2,
+ numeric_limits<int8_t>::max() / 2);
+
+ vector<int> row_indices(K_adjusted);
+
+ int total_nnz = 0;
+ for (int j = 0; j < N_adjusted; ++j) {
+ B_csc.ColPtr()[j] = total_nnz;
+
+ int nnz_of_j = per_col_nnz_dist(eng);
+ total_nnz += nnz_of_j;
+
+ iota(row_indices.begin(), row_indices.end(), 0);
+ shuffle(row_indices.begin(), row_indices.end(), eng);
+ sort(row_indices.begin(), row_indices.begin() + nnz_of_j);
+
+ for (int k = 0; k < nnz_of_j; ++k) {
+ B_csc.RowIdx().push_back(row_indices[k]);
+ B_csc.Values().push_back(value_dist(eng));
+ }
+ }
+ B_csc.ColPtr()[N_adjusted] = total_nnz;
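+ // B_csc now holds B in standard CSC form: column j's nonzeros live at
+ // positions [ColPtr()[j], ColPtr()[j + 1]) of RowIdx()/Values(), and
+ // ColPtr() has N_adjusted + 1 entries ending in the total nnz count.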
+
+ spmdm_ref(
+ M,
+ A.data() + (test_ld ? K_adjusted : 0),
+ K,
+ B_csc,
+ accumulation,
+ C_ref.data() + (test_ld ? N_adjusted : 0),
+ N);
+
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ if (accumulation) {
+ C[i * N + j] = i + j;
+ } else {
+ C[i * N + j] = i + j + 1;
+ }
+ }
+ }
+
+#pragma omp parallel
+ {
+#ifdef _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ int i_per_thread = (M + num_threads - 1) / num_threads;
+ int i_begin = std::min(tid * i_per_thread, M);
+ int i_end = std::min(i_begin + i_per_thread, M);
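+ // Rows of C are split evenly across threads, e.g. M = 111 with 4
+ // threads gives i_per_thread = 28 and row ranges [0,28), [28,56),
+ // [56,84), [84,111), the last thread taking the 27 leftover rows.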
+
+ block_type_t block = {i_begin, i_end - i_begin, 0, N_adjusted};
+ B_csc.SpMDM(
+ block,
+ A.data() + (test_ld ? K_adjusted : 0),
+ K,
+ accumulation,
+ C.data() + i_begin * N + (test_ld ? N_adjusted : 0),
+ N);
+ }
+
+ compare_validate_buffers(
+ C_ref.data() + (test_ld ? N_adjusted : 0),
+ C.data() + (test_ld ? N_adjusted : 0),
+ M,
+ N_adjusted,
+ N,
+ static_cast<int32_t>(0));
+ } // for each shape
+}
diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc
new file mode 100644
index 0000000..28e1114
--- /dev/null
+++ b/test/PackedRequantizeAcc16Test.cc
@@ -0,0 +1,535 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include <gtest/gtest.h>
+
+#include "fbgemm/Fbgemm.h"
+#include "bench/BenchUtils.h"
+#include "src/RefImplementations.h"
+#include "QuantizationHelpers.h"
+#include "TestUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+std::vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose};
+
+namespace {
+class fbgemmu8s8acc16test : public testing::TestWithParam<
+ std::tuple<matrix_op_t, matrix_op_t, bool>> {};
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmu8s8acc16test,
+ ::testing::Combine(
+ ::testing::Values(matrix_op_t::NoTranspose),
+ ::testing::ValuesIn(transposeVals),
+ ::testing::Bool()));
+
+/**
+ * @brief Shapes for unit test.
+ */
+static vector<vector<int>> GetShapes_() {
+ // NMT
+ vector<vector<int>> shapes = {
+ // {M, N, K}
+ {1, 128, 512},
+ {1, 1024, 256},
+ {1, 2048, 512},
+ {1, 2048, 513},
+ {1, 2048, 514},
+
+ {6, 512, 512},
+ {6, 2048, 512},
+ {6, 256, 1024},
+ {6, 1024, 256},
+ {6, 2048, 256},
+ {6, 2048, 257},
+ {6, 2048, 258},
+
+ {102, 1024, 512},
+ {102, 2323, 256},
+ {102, 512, 256},
+ {102, 512, 257},
+ {102, 512, 258},
+ };
+ return shapes;
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(fbgemmu8s8acc16test, Test) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+ aligned_vector<int8_t> Bint8_ref(k * n, 0);
+ aligned_vector<int32_t> Cint32_local(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_local(m * n, 0);
+
+ randFill(Aint8, 0, 255);
+ int32_t Aint8_zero_point = 43;
+
+ randFill(Bint8_ref, -128, 127);
+
+ for (auto i = 0; i < Bint8.size(); ++i) {
+ Bint8[i] = Bint8_ref[i];
+ }
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ int32_t Bint8_zero_point = -30;
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int k_adjusted = k;
+ int n_adjusted = n;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8_ref.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ float C_multiplier = 0.1234;
+ int32_t C_zero_pt = 5;
+
+ int brow = 256;
+ matmul_u8i8acc16_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ brow,
+ Aint8.data(),
+ Bint8_ref.data(),
+ Cint32_local.data());
+
+ row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data());
+
+ requantize_u8acc32_ref(
+ m,
+ n_adjusted,
+ n,
+ Cint32_local.data(),
+ Cint8_local.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ nullptr);
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t, int16_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
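+ // Output pipeline, constructed from its last stage first: the kernel
+ // writes int32 accumulators, ReQuantizeOutput folds in the zero-point
+ // corrections (via the row/column offsets) and applies
+ // C_multiplier / C_zero_pt, and DoNothing terminates the chain, so the
+ // result should match requantize_u8acc32_ref above.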
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false> outputProcObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint8_fb.data(),
+ Cint32_buffer.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+
+ compare_validate_buffers(
+ Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast<uint8_t>(0));
+ }
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit
+ * accumulation. Output processing: spmdm -> requantization -> nothing
+ */
+TEST_P(fbgemmu8s8acc16test, SpMDMTest) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+ aligned_vector<int8_t> Bint8_ref(k * n, 0);
+ aligned_vector<int32_t> Cint32_local(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_local(m * n, 0);
+
+ randFill(Aint8, 0, 255);
+ int32_t Aint8_zero_point = 43;
+
+ randFill(Bint8, -128, 127);
+
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int k_adjusted = k;
+ int n_adjusted = n;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ int32_t Bint8_zero_point = -30;
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ CompressedSparseColumn B_csc(k_adjusted, n_adjusted);
+ float density = 0.001f;
+ // deterministic random number
+ default_random_engine eng;
+ binomial_distribution<> per_col_nnz_dist(k_adjusted, density);
+ uniform_int_distribution<> value_dist(
+ numeric_limits<int8_t>::min() / 2, numeric_limits<int8_t>::max() / 2);
+
+ vector<int> row_indices(k_adjusted);
+ int total_nnz = 0;
+ for (int j = 0; j < n_adjusted; ++j) {
+ B_csc.ColPtr()[j] = total_nnz;
+
+ int nnz_of_j = per_col_nnz_dist(eng);
+ total_nnz += nnz_of_j;
+
+ iota(row_indices.begin(), row_indices.end(), 0);
+ shuffle(row_indices.begin(), row_indices.end(), eng);
+ sort(row_indices.begin(), row_indices.begin() + nnz_of_j);
+
+ for (int kidx = 0; kidx < nnz_of_j; ++kidx) {
+ B_csc.RowIdx().push_back(row_indices[kidx]);
+ // put the current B value
+ B_csc.Values().push_back(Bint8[row_indices[kidx] * n + j]);
+ // make current B value zero
+ Bint8[row_indices[kidx] * n + j] = 0;
+ }
+ }
+ B_csc.ColPtr()[n_adjusted] = total_nnz;
+
+ for (auto i = 0; i < Bint8.size(); ++i) {
+ Bint8_ref[i] = Bint8[i];
+ }
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ float C_multiplier = 0.1234;
+ int32_t C_zero_pt = 5;
+
+ int brow = 256;
+ matmul_u8i8acc16_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ brow,
+ Aint8.data(),
+ Bint8_ref.data(),
+ Cint32_local.data());
+
+ bool accumulation = true;
+ spmdm_ref(m, Aint8.data(), k, B_csc, accumulation, Cint32_local.data(), n);
+
+ row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data());
+
+ requantize_u8acc32_ref(
+ m,
+ n_adjusted,
+ n,
+ Cint32_local.data(),
+ Cint8_local.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ nullptr);
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t, int16_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ // spmdm -> requantization -> nothing
+ // construct an output processing pipeline in reverse order
+ // i.e. last output operation first
+ // Last operation should always be DoNothing with
+ // correct input and output type.
+ DoNothing<> doNothingObj{};
+ // The second last operation is requantization back
+ // to int8
+ ReQuantizeOutput<false> reqObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+ // the top most (first) operation in the output processing
+ // pipeline is spmdm
+ // outType = final output type after fully processing through the pipeline
+ // inType = initial input type at the first call to the whole pipeline
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<false>::outType,
+ int32_t,
+ ReQuantizeOutput<false>>
+ spmdmObj(reqObj, Aint8.data(), k, B_csc);
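+ // spmdmObj adds A * B_csc (the sparse entries that were split out of
+ // Bint8 and zeroed above) into the int32 buffer before handing it to
+ // the requantization stage, mirroring the matmul + spmdm_ref reference
+ // computation.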
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n);
+
+ fbgemmPacked(
+ packAN, packedB, Cint8_fb.data(), Cint32_fb.data(), n, spmdmObj, 0, 1);
+
+ compare_validate_buffers(
+ Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast<uint8_t>(0));
+ }
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit
+ * accumulation. Output processing: just a memcpy of the int32 accumulators.
+ */
+TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+ aligned_vector<int8_t> Bint8_ref(k * n, 0);
+ aligned_vector<int32_t> Cint32_local(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_local(m * n, 0);
+
+ randFill(Aint8, 0, 255);
+ int32_t Aint8_zero_point = 43;
+
+ randFill(Bint8_ref, -128, 127);
+
+ for (auto i = 0; i < Bint8.size(); ++i) {
+ Bint8[i] = Bint8_ref[i];
+ }
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ int32_t Bint8_zero_point = -30;
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int k_adjusted = k;
+ int n_adjusted = n;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8_ref.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ int brow = 256;
+ matmul_u8i8acc16_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ brow,
+ Aint8.data(),
+ Bint8_ref.data(),
+ Cint32_local.data());
+
+ row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data());
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t, int16_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ // DoNothing<> doNothingObj{};
+ DoNothing<int32_t, int32_t> doNothingObj{};
+ memCopy<> outputProcObj(doNothingObj);
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint32_fb.data(),
+ Cint32_buffer.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+
+ compare_validate_buffers(
+ Cint32_local.data(),
+ Cint32_fb.data(),
+ m,
+ n,
+ n,
+ static_cast<int32_t>(0));
+ }
+}
diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc
new file mode 100644
index 0000000..4b1b5b8
--- /dev/null
+++ b/test/PackedRequantizeTest.cc
@@ -0,0 +1,625 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include <gtest/gtest.h>
+
+#include "fbgemm/Fbgemm.h"
+#include "bench/BenchUtils.h"
+#include "src/RefImplementations.h"
+#include "QuantizationHelpers.h"
+#include "TestUtils.h"
+
+using namespace std;
+using namespace fbgemm2;
+
+std::vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose};
+
+namespace {
+class fbgemmu8s8acc32test : public testing::TestWithParam<
+ std::tuple<matrix_op_t, matrix_op_t, bool>> {};
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmu8s8acc32test,
+ ::testing::Combine(
+ ::testing::Values(matrix_op_t::NoTranspose),
+ ::testing::ValuesIn(transposeVals),
+ ::testing::Bool()));
+
+/**
+ * @brief Shapes for unit test.
+ */
+static vector<vector<int>> GetShapes_() {
+ // NMT
+ vector<vector<int>> shapes = {
+ // {M, N, K}
+ {1, 128, 512},
+ {1, 1024, 256},
+ {1, 2048, 512},
+ {1, 2048, 513},
+ {1, 2048, 514},
+
+ {6, 512, 512},
+ {6, 2048, 512},
+ {6, 256, 1024},
+ {6, 1024, 256},
+ {6, 2048, 256},
+ {6, 2048, 257},
+ {6, 2048, 258},
+
+ {102, 1024, 512},
+ {102, 2323, 256},
+ {102, 512, 256},
+ {102, 512, 257},
+ {102, 512, 258},
+ };
+ return shapes;
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(fbgemmu8s8acc32test, Test) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ // B matrix; stored transposed (n x k) below when btrans == Transpose
+ aligned_vector<int8_t> Bint8(k * n, 0);
+ // reference copy of B, kept in row-major k x n form
+ aligned_vector<int8_t> Bint8_ref(k * n, 0);
+
+ aligned_vector<int32_t> Cint32_ref(m * n, 0.0f);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<int32_t> Cint32_local(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+ aligned_vector<uint8_t> Cint8_local(m * n, 0);
+
+ randFill(Aint8, 0, 255);
+ int32_t Aint8_zero_point = 43;
+
+ randFill(Bint8_ref, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8_ref.data());
+
+ for (auto i = 0; i < Bint8.size(); ++i) {
+ Bint8[i] = Bint8_ref[i];
+ }
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ int32_t Bint8_zero_point = -30;
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int k_adjusted = k;
+ int n_adjusted = n;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8_ref.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ float C_multiplier = 0.1234;
+ int32_t C_zero_pt = 5;
+
+ matmul_u8i8acc32_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ Aint8.data(),
+ Bint8_ref.data(),
+ Cint32_local.data());
+
+ row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data());
+
+ requantize_u8acc32_ref(
+ m,
+ n_adjusted,
+ n,
+ Cint32_local.data(),
+ Cint8_local.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ nullptr);
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(PackAWithRowOffset<uint8_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false> outputProcObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint8_fb.data(),
+ Cint32_buffer.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_local.data(),
+ // m, n_adjusted, n, "C local");
+ compare_validate_buffers(
+ Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast<uint8_t>(0));
+ }
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit
+ * accumulation. Directly output fp32 matrix C. Output processing:
+ * requantization -> nothing
+ */
+TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<float> Afp32(m * k, 0.0f);
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ aligned_vector<float> Bfp32(k * n, 0.0f);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+
+ aligned_vector<float> Cfp32_ref(m * n, 0.0f);
+ aligned_vector<float> Cfp32_fb(m * n, 0.0f);
+
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<int32_t> Cint32_buffer(m * n, 0);
+
+ randFill(Aint8, 0, 255);
+ int32_t Aint8_zero_point = 43;
+ float Aint8_scale = 0.11;
+ for (auto i = 0; i < Afp32.size(); ++i) {
+ Afp32[i] = Aint8_scale * (Aint8[i] - Aint8_zero_point);
+ }
+
+ randFill(Bint8, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8.data());
+ int32_t Bint8_zero_point = -30;
+ float Bint8_scale = 0.49;
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bfp32[i] = Bint8_scale * (Bint8[i] - Bint8_zero_point);
+ }
+
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int k_adjusted = k;
+ int n_adjusted = n;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ matmul_fp_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ Afp32.data(),
+ Bfp32.data(),
+ Cfp32_ref.data());
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize());
+
+ PackAWithQuantRowOffset<uint8_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Afp32.data(),
+ k,
+ nullptr, /*buffer for packed matrix*/
+ Aint8_scale,
+ Aint8_zero_point,
+ 1, /*groups*/
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ DoNothing<float, float> doNothingObj{};
+ ReQuantizeForFloat<false> outputProcObj(
+ doNothingObj,
+ Aint8_scale,
+ Bint8_scale,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr);
+
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cfp32_fb.data(),
+ (int32_t*)Cfp32_fb.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+
+ float maximum = *max_element(Cfp32_ref.begin(), Cfp32_ref.end());
+ float minimum = *min_element(Cfp32_ref.begin(), Cfp32_ref.end());
+ float atol = (maximum - minimum) / 255 / 1.9;
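+ // atol is roughly half of one uint8 quantization step of the output
+ // range, e.g. a range of 100 gives a step of 100 / 255 ~= 0.39 and
+ // atol ~= 0.21.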
+
+ compare_validate_buffers(Cfp32_ref.data(), Cfp32_fb.data(), m, n, n, atol);
+ }
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit
+ * accumulation. Output processing: memcpy of the int32 accumulators (no
+ * requantization). Symmetric: both zero points are 0.
+ */
+TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<float> Afp32(m * k, 0.0f);
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+
+ aligned_vector<float> Bfp32(k * n, 0.0f);
+ aligned_vector<int8_t> Bint8(k * n, 0);
+
+ aligned_vector<float> Cfp32_ref(m * n, 0.0f);
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+
+ randFill(Afp32, 0, 255);
+ for (auto i = 0; i < Afp32.size(); i++) {
+ Aint8[i] = (uint8_t)Afp32[i];
+ }
+
+ // initialize B matrix
+ randFill(Bfp32, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bfp32.data());
+
+ for (auto i = 0; i < Bfp32.size(); ++i) {
+ Bint8[i] = (int8_t)Bfp32[i];
+ }
+
+ // To test lda != k, we just reduce k by half and use the original k
+ // as lda.
+ int m_adjusted = m;
+ int n_adjusted = n;
+ int k_adjusted = k;
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ matmul_fp_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ Afp32.data(),
+ Bfp32.data(),
+ Cfp32_ref.data());
+
+ DoNothing<int32_t, int32_t> doNothingObj{};
+ memCopy<> outputProcObj(doNothingObj);
+ // A zero point and row offset not required
+ PackAMatrix<uint8_t> packAN(
+ matrix_op_t::NoTranspose, m, k_adjusted, Aint8.data(), k);
+
+ // B zero point defaults to 0
+ PackBMatrix<int8_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n);
+
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+
+ // correctness check
+ for (int i = 0; i < m_adjusted; ++i) {
+ for (int j = 0; j < n_adjusted; ++j) {
+ float expected = Cfp32_ref[i * n + j];
+ int32_t actual = Cint32_fb[i * n + j];
+ EXPECT_EQ(expected, actual)
+ << "GEMM results differ at (" << i << ", " << j << "). ref "
+ << expected << " FBGemm " << actual;
+ }
+ }
+ }
+}
+
+/**
+ * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit
+ * accumulation. Output processing: requantization with bias -> nothing.
+ * Asymmetric: the zero point is not 0.
+ */
+TEST_P(fbgemmu8s8acc32test, TestAsymmetricQuantizedWithBias) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ bool test_ld;
+ tie(atrans, btrans, test_ld) = GetParam();
+
+ for (auto shape : shapes) {
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k, 0);
+ aligned_vector<uint8_t> Aint8_ref(m * k, 0);
+
+ aligned_vector<int8_t> Bint8(k * n, 0);
+ aligned_vector<int8_t> Bint8_ref(k * n, 0);
+
+ aligned_vector<int32_t> Cint32_fb(m * n, 0);
+ aligned_vector<int32_t> Cint32_ref(m * n, 0);
+
+ aligned_vector<uint8_t> Cint8_fb(m * n, 0);
+ aligned_vector<uint8_t> Cint8_ref(m * n, 0);
+
+ int n_adjusted = n;
+ int k_adjusted = k;
+
+ if (test_ld) {
+ assert(
+ atrans == matrix_op_t::NoTranspose && "This case is not handled yet");
+ k_adjusted = std::max(k / 2, 1);
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // A and B have scale 1, so exactly represented after quantization
+ randFill(Aint8, 0, 255);
+ randFill(Bint8, -128, 127);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8.data());
+
+ for (auto i = 0; i < Bint8.size(); ++i) {
+ Bint8_ref[i] = Bint8[i];
+ }
+
+ for (auto i = 0; i < Aint8.size(); ++i) {
+ Aint8_ref[i] = Aint8[i];
+ }
+
+ int32_t Aint8_zero_point = 55;
+ int32_t Bint8_zero_point = -17;
+
+ // initialize bias
+ aligned_vector<int32_t> bias_int32(n);
+ randFill(bias_int32, -128, 127);
+
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(Bint8.data(), k, n);
+ }
+
+ // computing column offset
+ vector<int32_t> col_offsets;
+ col_offsets.resize(n_adjusted);
+ col_offsets_with_zero_pt_s8acc32_ref(
+ k_adjusted,
+ n_adjusted,
+ n,
+ Bint8_ref.data(),
+ Bint8_zero_point,
+ col_offsets.data());
+
+ matmul_u8i8acc32_ref(
+ m,
+ n_adjusted,
+ k_adjusted,
+ k,
+ n,
+ n,
+ Aint8.data(),
+ Bint8_ref.data(),
+ Cint32_ref.data());
+
+ vector<int32_t> row_offsets;
+ row_offsets.resize(m);
+
+ row_offsets_u8acc32_ref(
+ m, k_adjusted, k, Aint8_ref.data(), row_offsets.data());
+
+ float C_multiplier = 0.1234;
+ int32_t C_zero_pt = 5;
+
+ requantize_u8acc32_ref(
+ m,
+ n_adjusted,
+ n,
+ Cint32_ref.data(),
+ Cint8_ref.data(),
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ row_offsets.data(),
+ col_offsets.data(),
+ bias_int32.data());
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(PackAWithRowOffset<uint8_t>::rowOffsetBufferSize());
+
+ PackAWithRowOffset<uint8_t> packAN(
+ matrix_op_t::NoTranspose,
+ m,
+ k_adjusted,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ PackBMatrix<int8_t> packedBN(
+ btrans,
+ k_adjusted,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k : n,
+ nullptr,
+ 1,
+ Bint8_zero_point);
+
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false> outputProcObj(
+ doNothingObj,
+ C_multiplier,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point,
+ packAN.getRowOffsetBuffer(),
+ col_offsets.data(),
+ bias_int32.data());
+
+ fbgemmPacked(
+ packAN,
+ packedBN,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ n,
+ outputProcObj,
+ 0,
+ 1);
+
+ compare_validate_buffers(
+ Cint8_fb.data(), Cint8_ref.data(), m, n, n, static_cast<uint8_t>(0));
+ }
+}
diff --git a/test/QuantizationHelpers.cc b/test/QuantizationHelpers.cc
new file mode 100644
index 0000000..354519b
--- /dev/null
+++ b/test/QuantizationHelpers.cc
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "QuantizationHelpers.h"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+using namespace std;
+
+namespace fbgemm2 {
+/*
+ * @brief Make sure we won't have overflows from vpmaddubsw instruction.
+ */
+template <typename T>
+void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B) {
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ for (int kk = 0; kk < k / 2 * 2; kk += 2) {
+ int a0 = Aint8[i * k + kk], a1 = Aint8[i * k + kk + 1];
+ int b0 = B[kk * n + j], b1 = B[(kk + 1) * n + j];
+ int sum_pair = a0 * b0 + a1 * b1;
+ if (sum_pair < numeric_limits<int16_t>::lowest()) {
+ int b1_adjusted =
+ ceil((numeric_limits<int16_t>::lowest() - a0 * b0) / a1);
+ b1_adjusted = std::min(std::max(b1_adjusted, -128), 127);
+
+ int new_sum_pair = a0 * b0 + a1 * b1_adjusted;
+ assert(
+ new_sum_pair >= numeric_limits<int16_t>::lowest() &&
+ new_sum_pair <= numeric_limits<int16_t>::max());
+ B[(kk + 1) * n + j] = b1_adjusted;
+ } else if (sum_pair > numeric_limits<int16_t>::max()) {
+ int b1_adjusted =
+ floor((numeric_limits<int16_t>::max() - a0 * b0) / a1);
+ b1_adjusted = std::min(std::max(b1_adjusted, -128), 127);
+
+ int new_sum_pair = a0 * b0 + a1 * b1_adjusted;
+ assert(
+ new_sum_pair >= numeric_limits<int16_t>::lowest() &&
+ new_sum_pair <= numeric_limits<int16_t>::max());
+ B[(kk + 1) * n + j] = b1_adjusted;
+ }
+ }
+ } // for each j
+ } // for each i
+}
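+
+// Worked example of the adjustment above: a0 = a1 = 255 and b0 = b1 = 127
+// give sum_pair = 2 * 255 * 127 = 64770 > 32767, so b1 is rewritten as
+// (32767 - 32385) / 255 = 1 (integer division; the truncation acts as the
+// intended floor/ceil here) and the new pair sum 32385 + 255 = 32640 fits
+// in int16.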
+
+template void
+avoidOverflow(int m, int n, int k, const uint8_t* Aint8, int8_t* B);
+template void
+avoidOverflow(int m, int n, int k, const uint8_t* Aint8, float* B);
+} // namespace fbgemm2
diff --git a/test/QuantizationHelpers.h b/test/QuantizationHelpers.h
new file mode 100644
index 0000000..fdc6e02
--- /dev/null
+++ b/test/QuantizationHelpers.h
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cinttypes>
+
+namespace fbgemm2 {
+
+/*
+ * @brief Make sure we won't have overflows from vpmaddubsw instruction.
+ */
+template <typename T>
+void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B);
+
+} // namespace fbgemm2
diff --git a/test/TestUtils.cc b/test/TestUtils.cc
new file mode 100644
index 0000000..702425e
--- /dev/null
+++ b/test/TestUtils.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "TestUtils.h"
+#include <gtest/gtest.h>
+#include "fbgemm/Fbgemm.h"
+#include "bench/AlignedVec.h"
+
+namespace fbgemm2 {
+
+template <typename T>
+int compare_validate_buffers(
+ const T* ref,
+ const T* test,
+ int m,
+ int n,
+ int ld,
+ T atol) {
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ if (std::is_integral<T>::value) {
+ EXPECT_EQ(ref[i * ld + j], test[i * ld + j])
+ << "GEMM results differ at (" << i << ", " << j
+ << ") reference: " << (int64_t)ref[i * ld + j]
+ << ", FBGEMM: " << (int64_t)test[i * ld + j];
+ } else {
+ EXPECT_LE(std::abs(ref[i * ld + j] - test[i * ld + j]), atol)
+ << "GEMM results differ at (" << i << ", " << j
+ << ") reference: " << ref[i * ld + j]
+ << ", FBGEMM: " << test[i * ld + j];
+ }
+ }
+ }
+ return 0;
+}
+
+template int compare_validate_buffers<float>(
+ const float* ref,
+ const float* test,
+ int m,
+ int n,
+ int ld,
+ float atol);
+
+template int compare_validate_buffers<int32_t>(
+ const int32_t* ref,
+ const int32_t* test,
+ int m,
+ int n,
+ int ld,
+ int32_t atol);
+
+template int compare_validate_buffers<uint8_t>(
+ const uint8_t* ref,
+ const uint8_t* test,
+ int m,
+ int n,
+ int ld,
+ uint8_t atol);
+
+template <typename T>
+bool check_all_zero_entries(const T* test, int m, int n) {
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ if (test[i * n + j] != 0)
+ return true;
+ }
+ }
+ return false;
+}
+
+template bool check_all_zero_entries<float>(const float* test, int m, int n);
+template bool
+check_all_zero_entries<int32_t>(const int32_t* test, int m, int n);
+template bool
+check_all_zero_entries<uint8_t>(const uint8_t* test, int m, int n);
+
+template <typename T>
+void transpose_matrix(T* ref, int n, int k) {
+ aligned_vector<T> local(n * k, 0);
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < k; ++j) {
+ local[j * n + i] = ref[i * k + j];
+ }
+ }
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < n; ++j) {
+ ref[i * n + j] = local[i * n + j];
+ }
+ }
+}
+
+template void transpose_matrix<float>(float* ref, int n, int k);
+template void transpose_matrix<int32_t>(int32_t* ref, int n, int k);
+template void transpose_matrix<uint8_t>(uint8_t* ref, int n, int k);
+template void transpose_matrix<int8_t>(int8_t* ref, int n, int k);
+} // namespace fbgemm2
diff --git a/test/TestUtils.h b/test/TestUtils.h
new file mode 100644
index 0000000..559f816
--- /dev/null
+++ b/test/TestUtils.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cmath>
+#include <vector>
+
+namespace fbgemm2 {
+
+/*
+ * @brief Check and validate the buffers for reference and FBGEMM result.
+ */
+template <typename T>
+int compare_validate_buffers(
+ const T* ref,
+ const T* test,
+ int m,
+ int n,
+ int ld,
+ T atol);
+
+/*
+ * @brief Check if all entries are zero or not.
+ * If any entry is non-zero, return True;
+ * otherwise, return False.
+ */
+template <typename T>
+bool check_all_zero_entries(const T* test, int m, int n);
+
+/*
+ * @brief In-place transposition of an n x k matrix ref (result is k x n).
+ * @param n number of rows in the input (number of columns in the output)
+ * @param k number of columns in the input (number of rows in the output)
+ */
+template <typename T>
+void transpose_matrix(T* ref, int n, int k);
+} // namespace fbgemm2