From e85b5a12254fa47ca6b56236489253a68fd32104 Mon Sep 17 00:00:00 2001 From: Daya S Khudia Date: Fri, 12 Oct 2018 15:48:13 -0700 Subject: Initial commit --- CMakeLists.txt | 171 ++ CODE_OF_CONDUCT.md | 5 + CONTRIBUTING.md | 34 + LICENSE | 30 + README.md | 93 ++ bench/AlignedVec.h | 124 ++ bench/BenchUtils.cc | 44 + bench/BenchUtils.h | 18 + bench/CMakeLists.txt | 43 + bench/Depthwise3DBenchmark.cc | 265 ++++ bench/DepthwiseBenchmark.cc | 342 ++++ bench/FP16Benchmark.cc | 215 +++ bench/I8SpmdmBenchmark.cc | 212 +++ bench/Im2ColFusedRequantizeAcc16Benchmark.cc | 241 +++ bench/Im2ColFusedRequantizeAcc32Benchmark.cc | 242 +++ bench/PackedFloatInOutBenchmark.cc | 269 ++++ bench/PackedRequantizeAcc16Benchmark.cc | 439 +++++ bench/PackedRequantizeAcc32Benchmark.cc | 329 ++++ cmake/modules/DownloadASMJIT.cmake | 16 + cmake/modules/DownloadCPUINFO.cmake | 16 + cmake/modules/DownloadGTEST.cmake | 16 + cmake/modules/FindMKL.cmake | 270 ++++ include/fbgemm/ConvUtils.h | 95 ++ include/fbgemm/Fbgemm.h | 952 +++++++++++ include/fbgemm/FbgemmFP16.h | 160 ++ include/fbgemm/FbgemmI8Spmdm.h | 101 ++ include/fbgemm/OutputProcessing-inl.h | 356 +++++ include/fbgemm/PackingTraits-inl.h | 150 ++ include/fbgemm/Types.h | 115 ++ include/fbgemm/Utils.h | 123 ++ src/ExecuteKernel.cc | 12 + src/ExecuteKernel.h | 11 + src/ExecuteKernelGeneric.h | 64 + src/ExecuteKernelU8S8.cc | 354 +++++ src/ExecuteKernelU8S8.h | 73 + src/Fbgemm.cc | 363 +++++ src/FbgemmFP16.cc | 293 ++++ src/FbgemmFP16UKernels.cc | 2203 ++++++++++++++++++++++++++ src/FbgemmFP16UKernels.h | 40 + src/FbgemmI8Depthwise.cc | 1953 +++++++++++++++++++++++ src/FbgemmI8Depthwise.h | 105 ++ src/FbgemmI8Spmdm.cc | 508 ++++++ src/GenerateKernel.h | 154 ++ src/GenerateKernelU8S8S32ACC16.cc | 292 ++++ src/GenerateKernelU8S8S32ACC16_avx512.cc | 295 ++++ src/GenerateKernelU8S8S32ACC32.cc | 310 ++++ src/GenerateKernelU8S8S32ACC32_avx512.cc | 312 ++++ src/PackAMatrix.cc | 165 ++ src/PackAWithIm2Col.cc | 146 ++ src/PackBMatrix.cc | 144 ++ src/PackMatrix.cc | 86 + src/PackWithQuantRowOffset.cc | 230 +++ src/PackWithRowOffset.cc | 211 +++ src/RefImplementations.cc | 608 +++++++ src/RefImplementations.h | 268 ++++ src/Utils.cc | 357 +++++ src/Utils_avx512.cc | 243 +++ src/codegen_fp16fp32.cc | 387 +++++ test/CMakeLists.txt | 47 + test/FP16Test.cc | 124 ++ test/I8DepthwiseTest.cc | 448 ++++++ test/I8DepthwiseTest.h | 38 + test/I8SpmdmTest.cc | 158 ++ test/PackedRequantizeAcc16Test.cc | 535 +++++++ test/PackedRequantizeTest.cc | 625 ++++++++ test/QuantizationHelpers.cc | 57 + test/QuantizationHelpers.h | 18 + test/TestUtils.cc | 100 ++ test/TestUtils.h | 40 + 69 files changed, 17863 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 bench/AlignedVec.h create mode 100644 bench/BenchUtils.cc create mode 100644 bench/BenchUtils.h create mode 100644 bench/CMakeLists.txt create mode 100644 bench/Depthwise3DBenchmark.cc create mode 100644 bench/DepthwiseBenchmark.cc create mode 100644 bench/FP16Benchmark.cc create mode 100644 bench/I8SpmdmBenchmark.cc create mode 100644 bench/Im2ColFusedRequantizeAcc16Benchmark.cc create mode 100644 bench/Im2ColFusedRequantizeAcc32Benchmark.cc create mode 100644 bench/PackedFloatInOutBenchmark.cc create mode 100644 bench/PackedRequantizeAcc16Benchmark.cc create mode 100644 bench/PackedRequantizeAcc32Benchmark.cc create mode 100644 cmake/modules/DownloadASMJIT.cmake create mode 100644 
cmake/modules/DownloadCPUINFO.cmake create mode 100644 cmake/modules/DownloadGTEST.cmake create mode 100644 cmake/modules/FindMKL.cmake create mode 100644 include/fbgemm/ConvUtils.h create mode 100644 include/fbgemm/Fbgemm.h create mode 100644 include/fbgemm/FbgemmFP16.h create mode 100644 include/fbgemm/FbgemmI8Spmdm.h create mode 100644 include/fbgemm/OutputProcessing-inl.h create mode 100644 include/fbgemm/PackingTraits-inl.h create mode 100644 include/fbgemm/Types.h create mode 100644 include/fbgemm/Utils.h create mode 100644 src/ExecuteKernel.cc create mode 100644 src/ExecuteKernel.h create mode 100644 src/ExecuteKernelGeneric.h create mode 100644 src/ExecuteKernelU8S8.cc create mode 100644 src/ExecuteKernelU8S8.h create mode 100644 src/Fbgemm.cc create mode 100644 src/FbgemmFP16.cc create mode 100644 src/FbgemmFP16UKernels.cc create mode 100644 src/FbgemmFP16UKernels.h create mode 100644 src/FbgemmI8Depthwise.cc create mode 100644 src/FbgemmI8Depthwise.h create mode 100644 src/FbgemmI8Spmdm.cc create mode 100644 src/GenerateKernel.h create mode 100644 src/GenerateKernelU8S8S32ACC16.cc create mode 100644 src/GenerateKernelU8S8S32ACC16_avx512.cc create mode 100644 src/GenerateKernelU8S8S32ACC32.cc create mode 100644 src/GenerateKernelU8S8S32ACC32_avx512.cc create mode 100644 src/PackAMatrix.cc create mode 100644 src/PackAWithIm2Col.cc create mode 100644 src/PackBMatrix.cc create mode 100644 src/PackMatrix.cc create mode 100644 src/PackWithQuantRowOffset.cc create mode 100644 src/PackWithRowOffset.cc create mode 100644 src/RefImplementations.cc create mode 100644 src/RefImplementations.h create mode 100644 src/Utils.cc create mode 100644 src/Utils_avx512.cc create mode 100644 src/codegen_fp16fp32.cc create mode 100644 test/CMakeLists.txt create mode 100644 test/FP16Test.cc create mode 100644 test/I8DepthwiseTest.cc create mode 100644 test/I8DepthwiseTest.h create mode 100644 test/I8SpmdmTest.cc create mode 100644 test/PackedRequantizeAcc16Test.cc create mode 100644 test/PackedRequantizeTest.cc create mode 100644 test/QuantizationHelpers.cc create mode 100644 test/QuantizationHelpers.h create mode 100644 test/TestUtils.cc create mode 100644 test/TestUtils.h diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..cfc47b5 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,171 @@ +cmake_minimum_required(VERSION 3.7 FATAL_ERROR) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") + +#install libraries into correct locations on all platforms +include(GNUInstallDirs) + +project(fbgemm VERSION 0.1 LANGUAGES CXX C) + +set(FBGEMM_LIBRARY_TYPE "default" CACHE STRING + "Type of library (shared, static, or default) to build") +set_property(CACHE FBGEMM_LIBRARY_TYPE PROPERTY STRINGS default static shared) +option(FBGEMM_BUILD_TESTS "Build fbgemm unit tests" ON) +option(FBGEMM_BUILD_BENCHMARKS "Build fbgemm benchmarks" ON) + +if(FBGEMM_BUILD_TESTS) + enable_testing() +endif() + +set(FBGEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_SOURCE_DIR}/third-party) +set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +#All the source files that either use avx2 instructions statically or JIT +#avx2/avx512 instructions. 
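+#These are compiled into the fbgemm_avx2 object library below with -mavx2 and
+#-mfma; sources that statically require avx512 instructions live in the
+#separate FBGEMM_AVX512_SRCS list so that only fbgemm_avx512 is built with
+#-mavx512f.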
+set(FBGEMM_AVX2_SRCS src/ExecuteKernel.cc + src/ExecuteKernelU8S8.cc + src/Fbgemm.cc + src/FbgemmFP16.cc + src/FbgemmFP16UKernels.cc + src/FbgemmI8Depthwise.cc + src/FbgemmI8Spmdm.cc + src/GenerateKernelU8S8S32ACC16.cc + src/GenerateKernelU8S8S32ACC16_avx512.cc + src/GenerateKernelU8S8S32ACC32.cc + src/GenerateKernelU8S8S32ACC32_avx512.cc + src/PackAMatrix.cc + src/PackAWithIm2Col.cc + src/PackBMatrix.cc + src/PackMatrix.cc + src/PackWithQuantRowOffset.cc + src/PackWithRowOffset.cc + src/RefImplementations.cc + src/Utils.cc) + +#check if compiler supports avx512 +include(CheckCXXCompilerFlag) +CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512) +if(NOT COMPILER_SUPPORTS_AVX512) + message(FATAL_ERROR "A compiler with AVX512 support is required.") +endif() + +#All the source files that use avx512 instructions statically +set(FBGEMM_AVX512_SRCS src/Utils_avx512.cc) + +set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h + include/fbgemm/OutputProcessing-inl.h + include/fbgemm/PackingTraits-inl.h + include/fbgemm/Utils.h + include/fbgemm/ConvUtils.h + include/fbgemm/Types.h + include/fbgemm/FbgemmI8Spmdm.h) + + +add_library(fbgemm_avx2 OBJECT ${FBGEMM_AVX2_SRCS}) +add_library(fbgemm_avx512 OBJECT ${FBGEMM_AVX512_SRCS}) + +set_target_properties(fbgemm_avx2 fbgemm_avx512 PROPERTIES + CXX_STANDARD 11 + CXX_EXTENSIONS NO) + +target_compile_options(fbgemm_avx2 PRIVATE + "-m64" "-mavx2" "-mfma" "-masm=intel") +target_compile_options(fbgemm_avx512 PRIVATE + "-m64" "-mavx2" "-mfma" "-mavx512f" "-masm=intel") + +if(NOT TARGET asmjit) + #Download asmjit from github if ASMJIT_SRC_DIR is not specified. + if(NOT DEFINED ASMJIT_SRC_DIR) + message(STATUS "Downloading asmjit to ${FBGEMM_THIRDPARTY_DIR}/asmjit + (define ASMJIT_SRC_DIR to avoid it)") + configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadASMJIT.cmake" + "${FBGEMM_BINARY_DIR}/asmjit-download/CMakeLists.txt") + execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/asmjit-download") + execute_process(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/asmjit-download") + set(ASMJIT_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/asmjit" CACHE STRING + "asmjit source directory") + endif() + + #build asmjit + set(ASMJIT_STATIC ON) + add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit") +endif() + +if(NOT TARGET cpuinfo) + #Download cpuinfo from github if CPUINFO_SRC_DIR is not specified. + if(NOT DEFINED CPUINFO_SRC_DIR) + message(STATUS "Downloading cpuinfo to ${FBGEMM_THIRDPARTY_DIR}/cpuinfo + (define CPUINFO_SRC_DIR to avoid it)") + configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadCPUINFO.cmake" + "${FBGEMM_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") + execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download") + execute_process(COMMAND "${CMAKE_COMMAND}" --build . 
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download") + set(CPUINFO_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING + "cpuinfo source directory") + endif() + + #build cpuinfo + set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "Do not build cpuinfo unit tests") + set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests") + set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks") + set(CPUINFO_LIBRARY_TYPE static) + add_subdirectory("${CPUINFO_SRC_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo") + set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) +endif() + +target_include_directories(fbgemm_avx2 BEFORE + PUBLIC $ + PUBLIC $ + PRIVATE "${ASMJIT_SRC_DIR}/src" + PRIVATE "${CPUINFO_SRC_DIR}/include") + +target_include_directories(fbgemm_avx512 BEFORE + PUBLIC $ + PUBLIC $ + PRIVATE "${ASMJIT_SRC_DIR}/src" + PRIVATE "${CPUINFO_SRC_DIR}/include") + +if(FBGEMM_LIBRARY_TYPE STREQUAL "default") + add_library(fbgemm $ + $) +elseif(FBGEMM_LIBRARY_TYPE STREQUAL "shared") + add_library(fbgemm SHARED $ + $) +elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static") + add_library(fbgemm STATIC $ + $) +else() + message(FATAL_ERROR "Unsupported library type ${FBGEMM_LIBRARY_TYPE}") +endif() + +target_include_directories(fbgemm BEFORE + PUBLIC $ + PUBLIC $) + +target_link_libraries(fbgemm asmjit cpuinfo) +add_dependencies(fbgemm asmjit cpuinfo) + +install(TARGETS fbgemm EXPORT fbgemmLibraryConfig + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows + +install(FILES ${FBGEMM_PUBLIC_HEADERS} + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/fbgemm") + +#Make project importable from the build directory +#export(TARGETS fbgemm asmjit FILE fbgemmLibraryConfig.cmake) + +if(FBGEMM_BUILD_TESTS) + add_subdirectory(test) +endif() + +if(FBGEMM_BUILD_BENCHMARKS) + add_subdirectory(bench) +endif() diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..0f7ad8b --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..042f9e4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing to FBGEMM +We want to make contributing to this project as easy and transparent as +possible. + +## Code of Conduct +The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md). + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. 
+ +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to FBGEMM, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8299579 --- /dev/null +++ b/LICENSE @@ -0,0 +1,30 @@ +BSD License + +For FBGEMM software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..2b7efaa --- /dev/null +++ b/README.md @@ -0,0 +1,93 @@ +# FBGEMM +FBGEMM (Facebook GEneral Matrix Multiplication) is a low-precision, +high-performance matrix-matrix multiplications and convolution library for +server-side inference. + +The library provides efficient low-precision general matrix multiplication for +small batch sizes and support for accuracy-loss minimizing techniques such as +row-wise quantization and outlier-aware quantization. FBGEMM also exploits +fusion opportunities in order to overcome the unique challenges of matrix +multiplication at lower precision with bandwidth-bound operations. + +## Examples + +The tests (in test folder) and benchmarks (in bench folder) are some great +examples of using FBGEMM. For instance, SpMDMTest test in +test/PackedRequantizeAcc16Test.cc shows how to combine row offset calculations +with packing of A (PackAWithRowOffset), how to pack B matrix (PackBMatrix) and +construct output pipeline (sparse\_matrix\*dense\_matrix --> requantization --> +nop) fused with inner GEMM macro kernel. + +## Build Notes +FBGEMM uses the standard CMAKE-based build flow. + +### Dependencies +FBGEMM requires gcc 4.9+ and a CPU with support for avx2 instruction set or +higher. It's been tested on Mac OS X and Linux. 
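+A minimal sketch of the flow described in the Examples section above is shown
+below. It is illustrative only: the template arguments and the
+`gemm_u8s8s32` wrapper are assumptions, and the requantization stage is
+replaced by a plain copy of the int32 accumulators; the tests and benchmarks
+remain the authoritative usage reference.
+```
+// Illustrative sketch only (not part of the library). Template arguments and
+// this wrapper function are assumptions; see bench/ and test/ for the
+// authoritative usage.
+#include <cstdint>
+#include <vector>
+#include "fbgemm/Fbgemm.h"
+
+using namespace fbgemm2;
+
+// C = A (uint8) * B (int8) with 32-bit accumulation, no requantization.
+void gemm_u8s8s32(int M, int N, int K,
+                  const std::uint8_t* Aint8, std::int32_t A_zero_point,
+                  const std::int8_t* Bint8, std::int32_t* Cint32) {
+  // Per-row offsets of A are computed while A is packed; a requantization
+  // output stage (omitted here) would consume them.
+  std::vector<std::int32_t> row_offset_buf(
+      PackAWithRowOffset<std::uint8_t, std::int32_t>::rowOffsetBufferSize());
+
+  PackAWithRowOffset<std::uint8_t, std::int32_t> packA(
+      matrix_op_t::NoTranspose, M, K, Aint8, K,
+      nullptr, 1, A_zero_point, row_offset_buf.data());
+
+  PackBMatrix<std::int8_t, std::int32_t> packB(
+      matrix_op_t::NoTranspose, K, N, Bint8, N);
+
+  // Output pipeline: just copy the int32 accumulators out (the "nop" case);
+  // a requantization stage can be chained in here instead.
+  DoNothing<std::int32_t, std::int32_t> doNothingObj;
+  memCopy<> memcopyObj(doNothingObj);
+
+  fbgemmPacked(packA, packB, Cint32, Cint32, N, memcopyObj,
+               0 /* thread id */, 1 /* num threads */);
+}
+```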
+ ++ ###### asmjit +With inner kernels, FBGEMM takes a “one size doesn't fit all” approach, so the +implementation dynamically generates efficient matrix-shape specific vectorized +code using a third-party library called [asmjit][1]. **asmjit is required** to +build FBGEMM. + ++ ###### cpuinfo +FBGEMM detects CPU instruction set support at runtime using cpuinfo library and +dispatches optimized kernels for the detected instruction set. Therefore, +**cpuinfo is required** to detect CPU type. + ++ ###### googletest +googletest is required to build and run FBGEMM's tests. **googletest is not +required** if you don't want to run FBGEMM tests. By default, building of tests +is **on**. Turn it off by setting FBGEMM\_BUILD\_TESTS to off. + +You can download [asmjit][1], [cpuinfo][2], [googletest][3] and set +ASMJIT\_SRC\_DIR, CPUINFO\_SRC\_DIR, GOOGLETEST\_SOURCE\_DIR respectively for +cmake to find these libraries. If any of these variables is not set, cmake will +try to download that missing library in a folder called third-party in the +current directory and build it using the downloaded source code. + +FBGEMM, in general, does not have any dependency on Intel MKL. However, for +performance comparison, some benchmarks use MKL functions. If MKL is found or +MKL path is provided with INTEL\_MKL\_DIR benchmarks are built with MKL and +performance numbers are reported for MKL functions as well. However, if MKL is +not found, the benchmarks are not built. + +General build instructions are as follows: + +``` +mkdir build && cd build +cmake .. +make +``` + +To run the tests after building FBGEMM (if tests are built), use the following +command: +``` +make test +``` + +## Installing FBGEMM +``` +make install +``` + +## How FBGEMM works +For a high-level overview, design philosophy and brief descriptions of various +parts of FBGEMM please see [our blog][4]. + +## Full documentation +We have extensively used comments in our source files. The best and up-do-date +documentation is available in the source files. + +## Join the FBGEMM community +See the [`CONTRIBUTING`](CONTRIBUTING.md) file for how to help out. + +## License +FBGEMM is BSD licensed, as found in the [`LICENSE`](LICENSE) file. + + +[1]:https://github.com/asmjit/asmjit +[2]:https://github.com/pytorch/cpuinfo +[3]:https://github.com/google/googletest +[4]:https://code.fb.com/ai-research/ diff --git a/bench/AlignedVec.h b/bench/AlignedVec.h new file mode 100644 index 0000000..30fd266 --- /dev/null +++ b/bench/AlignedVec.h @@ -0,0 +1,124 @@ +#pragma once + +#include +#include +#include +#include + +/** + * Allocator for aligned data. + * + * Modified from the Mallocator from Stephan T. Lavavej. + * + * + */ +template class aligned_allocator { +public: + // The following will be the same for virtually all allocators. + typedef T *pointer; + typedef const T *const_pointer; + typedef T &reference; + typedef const T &const_reference; + typedef T value_type; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + + T *address(T &r) const { return &r; } + + const T *address(const T &s) const { return &s; } + + std::size_t max_size() const { + // The following has been carefully written to be independent of + // the definition of size_t and to avoid signed/unsigned warnings. + return (static_cast(0) - static_cast(1)) / + sizeof(T); + } + + // The following must be the same for all allocators. 
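+  // rebind lets allocator-aware containers obtain an equivalent allocator
+  // for a different value type while keeping the same Alignment.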
+ template struct rebind { + typedef aligned_allocator other; + }; + + bool operator!=(const aligned_allocator &other) const { + return !(*this == other); + } + + void construct(T *const p, const T &t) const { + void *const pv = static_cast(p); + + new (pv) T(t); + } + + void destroy(T *const p) const { p->~T(); } + + // Returns true if and only if storage allocated from *this + // can be deallocated from other, and vice versa. + // Always returns true for stateless allocators. + bool operator==(const aligned_allocator & /*other*/) const { return true; } + + // Default constructor, copy constructor, rebinding constructor, and + // destructor. Empty for stateless allocators. + aligned_allocator() {} + + aligned_allocator(const aligned_allocator &) {} + + template + aligned_allocator(const aligned_allocator &) {} + + ~aligned_allocator() {} + + // The following will be different for each allocator. + T *allocate(const std::size_t n) const { + // The return value of allocate(0) is unspecified. + // Mallocator returns NULL in order to avoid depending + // on malloc(0)'s implementation-defined behavior + // (the implementation can define malloc(0) to return NULL, + // in which case the bad_alloc check below would fire). + // All allocators can return NULL in this case. + if (n == 0) { + return nullptr; + } + + // All allocators should contain an integer overflow check. + // The Standardization Committee recommends that std::length_error + // be thrown in the case of integer overflow. + if (n > max_size()) { + throw std::length_error( + "aligned_allocator::allocate() - Integer overflow."); + } + + // Mallocator wraps malloc(). + void *pv = nullptr; + posix_memalign(&pv, Alignment, n * sizeof(T)); + // pv = aligned_alloc(Alignment, n * sizeof(T)); + + // Allocators should throw std::bad_alloc in the case of memory allocation + // failure. + if (pv == nullptr) { + throw std::bad_alloc(); + } + + return static_cast(pv); + } + + void deallocate(T *const p, const std::size_t /*n*/) const { free(p); } + + // The following will be the same for all allocators that ignore hints. + template + T *allocate(const std::size_t n, const U * /* const hint */) const { + return allocate(n); + } + + // Allocators are not required to be assignable, so + // all allocators should have a private unimplemented + // assignment operator. Note that this will trigger the + // off-by-default (enabled under /Wall) warning C4626 + // "assignment operator could not be generated because a + // base class assignment operator is inaccessible" within + // the STL headers, but that warning is useless. +private: + aligned_allocator &operator=(const aligned_allocator &) { assert(0); } +}; + +template +using aligned_vector = std::vector >; diff --git a/bench/BenchUtils.cc b/bench/BenchUtils.cc new file mode 100644 index 0000000..5dade2a --- /dev/null +++ b/bench/BenchUtils.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "BenchUtils.h" +#include + +namespace fbgemm2 { + +std::default_random_engine eng; + +template +void randFill(aligned_vector &vec, const int low, const int high) { + std::random_device r; + std::uniform_int_distribution dis(low, high); + for (auto &v : vec) { + v = static_cast(dis(eng)); + } +} + +template +void randFill(aligned_vector &vec, + const int low, const int high); +template +void randFill(aligned_vector &vec, + const int low, const int high); +template +void randFill(aligned_vector &vec, + const int low, const int high); + +template +void randFill(aligned_vector &vec, + const int low, const int high); + +void llc_flush(std::vector& llc) { + volatile char* data = llc.data(); + for (int i = 0; i < llc.size(); i++) { + data[i]++; + } +} + +} // namespace fbgemm2 diff --git a/bench/BenchUtils.h b/bench/BenchUtils.h new file mode 100644 index 0000000..5dd452e --- /dev/null +++ b/bench/BenchUtils.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include "bench/AlignedVec.h" + +namespace fbgemm2 { + +template +void randFill(aligned_vector &vec, const int low, const int high); + +void llc_flush(std::vector& llc); + +} // namespace fbgemm2 diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt new file mode 100644 index 0000000..338c37c --- /dev/null +++ b/bench/CMakeLists.txt @@ -0,0 +1,43 @@ +cmake_minimum_required(VERSION 3.7 FATAL_ERROR) + +find_package(MKL) +#benchmarks +macro(add_benchmark BENCHNAME) + add_executable(${BENCHNAME} ${ARGN} + BenchUtils.cc ../test/QuantizationHelpers.cc) + set_target_properties(${BENCHNAME} PROPERTIES + CXX_STANDARD 11 + CXX_EXTENSIONS NO) + target_compile_options(${BENCHNAME} PRIVATE + "-m64" "-mavx2" "-mfma" "-masm=intel") + target_link_libraries(${BENCHNAME} fbgemm) + add_dependencies(${BENCHNAME} fbgemm) + if(${MKL_FOUND}) + target_include_directories(${BENCHNAME} PRIVATE "${MKL_INCLUDE_DIR}") + target_link_libraries(${BENCHNAME} "${MKL_LIBRARIES}") + target_compile_options(${BENCHNAME} PRIVATE + "-DUSE_MKL") + endif() + set_target_properties(${BENCHNAME} PROPERTIES FOLDER test) +endmacro() + +if(FBGEMM_BUILD_BENCHMARKS) + + set(BENCHMARKS "") + + file(GLOB BENCH_LIST "*Benchmark.cc") + foreach(BENCH_FILE ${BENCH_LIST}) + get_filename_component(BENCH_NAME "${BENCH_FILE}" NAME_WE) + get_filename_component(BENCH_FILE_ONLY "${BENCH_FILE}" NAME) + add_benchmark("${BENCH_NAME}" + "${BENCH_FILE_ONLY}") + list(APPEND BENCHMARKS "${BENCH_NAME}") + endforeach() + + add_custom_target(run_benchmarks + COMMAND ${BENCHMARKS}) + + add_dependencies(run_benchmarks + ${BENCHMARKS}) + +endif() diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc new file mode 100644 index 0000000..b7c7d44 --- /dev/null +++ b/bench/Depthwise3DBenchmark.cc @@ -0,0 +1,265 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "test/I8DepthwiseTest.h" + +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "AlignedVec.h" +#include "src/FbgemmI8Depthwise.h" +#include "fbgemm/Utils.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +int main() { + // Depthwise is memory BW bound so we want to flush LLC. + bool flush = true; + std::vector llc; + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } +#define llc_flush() \ + for (auto i = 0; i < llc.size(); i++) { \ + llc[i]++; \ + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 16; + + for (auto shape : shapes_3d) { + int N = shape[0]; + int K = shape[1]; + int T = shape[2]; + int H = shape[3]; + int W = shape[4]; + int stride_t = shape[5]; + int stride_h = stride_t; + int stride_w = stride_t; + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + aligned_vector A(N * T * H * W * K); + aligned_vector B(K * K_T * K_H * K_W); + aligned_vector C_ref(N * T_OUT * H_OUT * W_OUT * K), + C(C_ref.size()); + + randFill(A, 0, 86); + int32_t A_zero_point = 43; + + randFill(B, -16, 16); + int32_t B_zero_point = 5; + + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); + int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + + float C_multiplier = 255. / (maximum - minimum); + + aligned_vector col_offsets(K); + aligned_vector bias(K); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + int32_t C_zero_point = 5; + + aligned_vector C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + B.data(), + C_multiplier, + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3x3ConvMatrix Bp(K, B.data()); + + double ttot = 0; + double bytes = + double(NITER) * + (K * (N * (2. 
* sizeof(int32_t) * T_OUT * H_OUT * W_OUT + T * H * W) + + K_T * K_H * K_W)); + double ops = + double(NITER) * N * T_OUT * H_OUT * W_OUT * K * K_T * K_H * K_W * 2; + chrono::time_point t_begin, t_end; + for (int i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(); + + t_begin = chrono::system_clock::now(); +#pragma omp parallel + { +#if _OPENMP + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#else + int num_threads = 1; + int tid = 0; +#endif + depthwise_3x3x3_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + Bp, + C.data(), + tid, + num_threads); + } + t_end = chrono::system_clock::now(); + if (i >= NWARMUP) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + + // correctness check + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int g = 0; g < K; ++g) { + int32_t expected = + C_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g]; + int32_t actual = + C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g]; + if (expected != actual) { + cerr << "Depthwise 3x3 results differ at (" << n << ", " << t + << ", " << h << ", " << w << ", " << g << "). expected " + << expected << " actual " << actual << endl; + return -1; + } + assert(expected == actual); + } + } // w + } // h + } // t + } // n + + // Report performance + printf("N = %d K = %d T = %d H = %d W = %d stride = %d\n", N, K, T, H, W, + stride_h); + printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9); + + ttot = 0; + for (int i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(); + + t_begin = chrono::system_clock::now(); +#pragma omp parallel + { +#if _OPENMP + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#else + int num_threads = 1; + int tid = 0; +#endif + depthwise_3x3x3_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_uint8.data(), + col_offsets.data(), + bias.data(), + false /* fuse_relu */, + tid, + num_threads); + } + t_end = chrono::system_clock::now(); + if (i >= NWARMUP) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + + // correctness check + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int g = 0; g < K; ++g) { + uint8_t expected = + C_uint8_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + + g]; + uint8_t actual = + C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + g]; + if (expected != actual) { + cerr << "Depthwise 3x3 results differ at (" << n << ", " << t + << ", " << h << ", " << w << ", " << g << "). expected " + << (int)expected << " actual " << (int)actual << endl; + return -1; + } + assert(expected == actual); + } + } // w + } // h + } // t + } // n + + // Report performance + printf("N = %d K = %d T = %d H = %d W = %d stride = %d with requantization " + "fused\n", + N, K, T, H, W, stride_h); + printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9); + } // for each shape + + return 0; +} diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc new file mode 100644 index 0000000..de08ff7 --- /dev/null +++ b/bench/DepthwiseBenchmark.cc @@ -0,0 +1,342 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. 
+ * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "AlignedVec.h" +#include "src/FbgemmI8Depthwise.h" +#include "fbgemm/Utils.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +int main() { + // From Xray OCR + vector> shapes = { + // N, G, H_in, W_in, stride + { 1, 272, 47, 125, 1, }, + { 1, 272, 64, 125, 1, }, + { 1, 272, 66, 125, 1, }, + { 1, 272, 67, 100, 1, }, + { 1, 272, 71, 125, 1, }, + { 1, 272, 74, 125, 1, }, + { 1, 272, 75, 75, 1, }, + { 1, 272, 75, 76, 1, }, + { 1, 272, 75, 79, 1, }, + { 1, 272, 75, 85, 1, }, + { 1, 272, 75, 100, 1, }, + { 1, 272, 75, 103, 1, }, + { 1, 272, 75, 111, 1, }, + { 1, 272, 75, 113, 1, }, + { 1, 272, 94, 75, 1, }, + { 1, 272, 109, 75, 1, }, + { 1, 272, 113, 75, 1, }, + { 1, 272, 117, 75, 1, }, + { 1, 544, 24, 63, 1, }, + { 1, 544, 32, 63, 1, }, + { 1, 544, 33, 63, 1, }, + { 1, 544, 34, 50, 1, }, + { 1, 544, 36, 63, 1, }, + { 1, 544, 37, 63, 1, }, + { 1, 544, 38, 38, 1, }, + { 1, 544, 38, 40, 1, }, + { 1, 544, 38, 43, 1, }, + { 1, 544, 38, 50, 1, }, + { 1, 544, 38, 52, 1, }, + { 1, 544, 38, 56, 1, }, + { 1, 544, 38, 57, 1, }, + { 1, 544, 47, 38, 1, }, + { 1, 544, 55, 38, 1, }, + { 1, 544, 57, 38, 1, }, + { 1, 544, 59, 38, 1, }, + { 1, 1088, 7, 7, 1, }, + { 51, 1088, 7, 7, 1, }, + { 59, 1088, 7, 7, 1, }, + { 70, 1088, 7, 7, 1, }, + { 71, 1088, 7, 7, 1, }, + { 77, 1088, 7, 7, 1, }, + { 79, 1088, 7, 7, 1, }, + { 84, 1088, 7, 7, 1, }, + { 85, 1088, 7, 7, 1, }, + { 89, 1088, 7, 7, 1, }, + { 93, 1088, 7, 7, 1, }, + { 96, 1088, 7, 7, 1, }, + { 100, 1088, 7, 7, 1, }, + + { 1, 248, 93, 250, 2, }, + { 1, 248, 128, 250, 2, }, + { 1, 248, 132, 250, 2, }, + { 1, 248, 131, 250, 2, }, + { 1, 248, 133, 200, 2, }, + { 1, 248, 141, 250, 2, }, + { 1, 248, 148, 250, 2, }, + { 1, 248, 150, 150, 2, }, + { 1, 248, 150, 151, 2, }, + { 1, 248, 150, 158, 2, }, + { 1, 248, 150, 169, 2, }, + { 1, 248, 150, 200, 2, }, + { 1, 248, 150, 205, 2, }, + { 1, 248, 150, 221, 2, }, + { 1, 248, 150, 225, 2, }, + { 1, 248, 188, 150, 2, }, + { 1, 248, 218, 150, 2, }, + { 1, 248, 225, 150, 2, }, + { 1, 248, 234, 150, 2, }, + { 1, 272, 47, 125, 2, }, + { 1, 272, 64, 125, 2, }, + { 1, 272, 66, 125, 2, }, + { 1, 272, 67, 100, 2, }, + { 1, 272, 71, 125, 2, }, + { 1, 272, 74, 125, 2, }, + { 1, 272, 75, 75, 2, }, + { 1, 272, 75, 76, 2, }, + { 1, 272, 75, 79, 2, }, + { 1, 272, 75, 85, 2, }, + { 1, 272, 75, 100, 2, }, + { 1, 272, 75, 103, 2, }, + { 1, 272, 75, 111, 2, }, + { 1, 272, 75, 113, 2, }, + { 1, 272, 94, 75, 2, }, + { 1, 272, 109, 75, 2, }, + { 1, 272, 113, 75, 2, }, + { 1, 272, 117, 75, 2, }, + { 1, 544, 14, 14, 2, }, + { 51, 544, 14, 14, 2, }, + { 59, 544, 14, 14, 2, }, + { 70, 544, 14, 14, 2, }, + { 71, 544, 14, 14, 2, }, + { 77, 544, 14, 14, 2, }, + { 79, 544, 14, 14, 2, }, + { 84, 544, 14, 14, 2, }, + { 85, 544, 14, 14, 2, }, + { 89, 544, 14, 14, 2, }, + { 93, 544, 14, 14, 2, }, + { 96, 544, 14, 14, 2, }, + { 100, 544, 14, 14, 2, }, + }; + + // Depthwise is memory BW bound so we want to flush LLC. 
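+  // Writing through a 128MB buffer between timed runs evicts the benchmark's
+  // working set from the last-level cache, so every iteration is measured
+  // with cold caches (i.e., limited by memory bandwidth).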
+ bool flush = true; + std::vector llc; + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } +#define llc_flush() \ + for (auto i = 0; i < llc.size(); i++) { \ + llc[i]++; \ + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 16; + + for (auto shape : shapes) { + int N = shape[0]; + int G = shape[1]; + int H = shape[2]; + int W = shape[3]; + int stride_h = shape[4]; + int stride_w = stride_h; + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + aligned_vector A(N * H * W * G); + aligned_vector B(G * R * S); + aligned_vector C_ref(N * H_OUT * W_OUT * G), C(C_ref.size()); + + randFill(A, 0, 86); + int32_t A_zero_point = 43; + + randFill(B, -16, 16); + int32_t B_zero_point = 5; + + depthwise_3x3_pad_1_ref( + N, + H, + W, + G, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); + int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + + float C_multiplier = 255. / (maximum - minimum); + + aligned_vector col_offsets(G); + aligned_vector bias(G); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + int32_t C_zero_point = 5; + + aligned_vector C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3_pad_1_ref( + N, + H, + W, + G, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + B.data(), + C_multiplier, + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3ConvMatrix Bp(G, B.data()); + + double ttot = 0; + double bytes = + double(NITER) * + (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S)); + double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2; + chrono::time_point t_begin, t_end; + for (int i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(); + + t_begin = chrono::system_clock::now(); +#pragma omp parallel + { +#ifdef _OPENMP + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#else + int num_threads = 1; + int tid = 0; +#endif + depthwise_3x3_pad_1( + N, + H, + W, + G, + stride_h, + stride_w, + A_zero_point, + A.data(), + Bp, + C.data(), + tid, + num_threads); + } + t_end = chrono::system_clock::now(); + if (i >= NWARMUP) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + + // correctness check + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int g = 0; g < G; ++g) { + int32_t expected = C_ref[((n * H_OUT + h) * W_OUT + w) * G + g]; + int32_t actual = C[((n * H_OUT + h) * W_OUT + w) * G + g]; + if (expected != actual) { + cerr << "Depthwise 3x3 results differ at (" << n << ", " + << h << ", " << w << ", " << g << "). 
expected " + << expected << " actual " << actual << endl; + return -1; + } + assert(expected == actual); + } + } + } + } + + // Report performance + printf("N = %d G = %d H = %d W = %d stride = %d\n", N, G, H, W, stride_h); + printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9); + + ttot = 0; + for (int i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(); + + t_begin = chrono::system_clock::now(); +#pragma omp parallel + { +#ifdef _OPENMP + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#else + int num_threads = 1; + int tid = 0; +#endif + depthwise_3x3_pad_1( + N, + H, + W, + G, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_uint8.data(), + col_offsets.data(), + bias.data(), + tid, + num_threads); + } + t_end = chrono::system_clock::now(); + if (i >= NWARMUP) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + + // correctness check + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int g = 0; g < G; ++g) { + uint8_t expected = + C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g]; + uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g]; + if (expected != actual) { + cerr << "Depthwise 3x3 results differ at (" << n << ", " + << h << ", " << w << ", " << g << "). expected " + << (int)expected << " actual " << (int)actual << endl; + return -1; + } + assert(expected == actual); + } + } + } + } + + // Report performance + printf( + "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n", + N, G, H, W, stride_h); + printf("GB/s = %f Gops/s = %f\n", bytes / ttot / 1e9, ops / ttot / 1e9); + } // for each shape + + return 0; +} diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc new file mode 100644 index 0000000..f5ec10f --- /dev/null +++ b/bench/FP16Benchmark.cc @@ -0,0 +1,215 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +#ifdef USE_MKL +#include +#endif + +#ifdef _OPENMP +#include +#endif + +#include "bench/BenchUtils.h" +#include "fbgemm/FbgemmFP16.h" +#include "AlignedVec.h" + +using namespace std; +using namespace fbgemm2; + +void performance_test() { + // cache flush + bool flush = true; + std::vector llc; + if (flush) { + llc.resize(64L * 1024L * 1024L, 1.0); + } + + float alpha = 1.f, beta = 1.f; + matrix_op_t btran = matrix_op_t::Transpose; + + using btype = float16; + +#define dataset 1 + +#if dataset == 1 + const int NITER = (flush) ? 10 : 100; + std::vector> shapes; + for (auto m = 1; m < 120; m++) { + // shapes.push_back({m, 128, 512}); + shapes.push_back({m, 512, 512}); + } + +#elif dataset == 2 + const int NITER = (flush) ? 
10 : 100; +#include "shapes_dataset.h" + +#else + flush = false; + constexpr int NITER = 1; + std::vector> shapes; + std::random_device r; + std::default_random_engine generator(r()); + std::uniform_int_distribution dm(1, 100); + std::uniform_int_distribution dnk(1, 1024); + for (int i = 0; i < 1000; i++) { + int m = dm(generator); + int n = dnk(generator); + int k = dnk(generator); + shapes.push_back({m, n, k}); + } +#endif + + std::string type; + double gflops, gbs, ttot; + for (auto s : shapes) { + int m = s[0]; + int n = s[1]; + int k = s[2]; + + aligned_vector A(m * k, 0.f); + aligned_vector B(k * n, 0.f); + aligned_vector Cg(m * n, 1.f); + aligned_vector Cp(m * n, NAN); + + // initialize with small numbers + randFill(A, 0, 4); + + randFill(B, 0, 4); + PackedGemmMatrixFP16 Bp(btran, k, n, alpha, B.data()); + + if (beta != 0.0f) { + randFill(Cg, 0, 4); + Cp = Cg; + } + + double nflops = 2.0 * (double)m * (double)n * (double)k * (double)NITER; + double nbytes = (4.0 * (double)m * (double)k + 2.0 * (double)k * (double)n + + 4.0 * (double)m * (double)n) * + NITER; + + // warm up MKL and fbgemm + // check correctness at the same time + for (auto w = 0; w < 3; w++) { +#ifdef USE_MKL + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + btran == matrix_op_t::Transpose ? CblasTrans : CblasNoTrans, + m, + n, + k, + alpha, + A.data(), + k, + B.data(), + (btran == matrix_op_t::NoTranspose) ? n : k, + beta, + Cg.data(), + n); +#endif + cblas_gemm_compute( + matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data()); + +#ifdef USE_MKL + // Compare results + for (auto i = 0; i < Cg.size(); i++) { + // printf("%f %f\n", Cg[i], Cp[i]); + assert(std::abs(Cg[i] - Cp[i]) < 1e-3); + } +#endif + } + + chrono::time_point t_begin, t_end; +#ifdef USE_MKL + // Gold via MKL sgemm + type = "MKL_FP32"; + ttot = 0; + for (auto it = -3; it < NITER; it++) { + if (flush) { + for (auto i = 0; i < llc.size(); i++) { + llc[i]++; + } + } + t_begin = chrono::system_clock::now(); + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + btran == matrix_op_t::Transpose ? CblasTrans : CblasNoTrans, + m, + n, + k, + alpha, + A.data(), + k, + B.data(), + (btran == matrix_op_t::NoTranspose) ? 
n : k, + beta, + Cg.data(), + n); + t_end = chrono::system_clock::now(); + if (it >= 0) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + gflops = nflops / ttot / 1e9; + gbs = nbytes / ttot / 1e9; + printf( + "\n%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n", + type.c_str(), + m, + n, + k, + gflops, + gbs); + ((volatile char*)(llc.data())); +#endif + + type = "FBP_" + std::string(typeid(btype).name()); + + ttot = 0; + for (auto it = -3; it < NITER; it++) { + if (flush) { + for (auto i = 0; i < llc.size(); i++) { + llc[i]++; + } + } + + t_begin = chrono::system_clock::now(); + cblas_gemm_compute( + matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data()); + t_end = chrono::system_clock::now(); + + if (it >= 0) { + double dt = chrono::duration(t_end - t_begin).count(); + ttot += dt; + } + } + gflops = nflops / ttot / 1e9; + gbs = nbytes / ttot / 1e9; + printf( + "%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n", + type.c_str(), + m, + n, + k, + gflops, + gbs); + ((volatile char*)(llc.data())); + } +} + +int main(int /*argc*/, char** /*argv*/) { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + + performance_test(); +} diff --git a/bench/I8SpmdmBenchmark.cc b/bench/I8SpmdmBenchmark.cc new file mode 100644 index 0000000..f97d152 --- /dev/null +++ b/bench/I8SpmdmBenchmark.cc @@ -0,0 +1,212 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "fbgemm/FbgemmI8Spmdm.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +int main() { + const vector> shapes = { + // M, N, K + {1024, 1024, 1024}, + {511, 512, 512}, + }; + + // SpMDM is often memory BW bound so we want to flush LLC. + bool flush = true; + std::vector llc; + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 16; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
+ << endl; + cout << "M, " + << "N, " + << "K, " + << "Density, " + << "Accumulation, " + << "Initialize (ms), " + << "Transpose uint8 (ms), " + << "Transpose 32xN (ms), " + << "Compute (ms), " + << "Transpose 32xN (ms), " + << "Total (ms), " + << "GB/s, " + << "GOPs" << endl; +#else + cout << "M, " + << "N, " + << "K, " + << "Density, " + << "Accumulation, " + << "GB/s, " + << "GOPs" << endl; +#endif + + for (const auto& shape : shapes) { + for (float density : {0.0001f, 0.001f, 0.01f, 0.1f, 1.0f}) { + for (bool accumulation : {false, true}) { + int M = shape[0]; + int N = shape[1]; + int K = shape[2]; + + cout << M << ", " << N << ", " << K << ", "; + + aligned_vector A(M * K); + randFill(A, 0, 255); + + fbgemm2::CompressedSparseColumn B_csc(K, N); + vector C(M * N); + vector C_ref(C.size()); + + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + C_ref[i * N + j] = i + j; + } + } + + // deterministic random number + std::default_random_engine eng; + binomial_distribution<> per_col_nnz_dist(K, density); + uniform_int_distribution<> value_dist( + numeric_limits::min() / 2, + numeric_limits::max() / 2); + + vector row_indices(K); + + int total_nnz = 0; + for (int j = 0; j < N; ++j) { + B_csc.ColPtr()[j] = total_nnz; + + int nnz_of_j = per_col_nnz_dist(eng); + total_nnz += nnz_of_j; + + iota(row_indices.begin(), row_indices.end(), 0); + shuffle(row_indices.begin(), row_indices.end(), eng); + sort(row_indices.begin(), row_indices.begin() + nnz_of_j); + + for (int k = 0; k < nnz_of_j; ++k) { + B_csc.RowIdx().push_back(row_indices[k]); + B_csc.Values().push_back(value_dist(eng)); + } + } + B_csc.ColPtr()[N] = total_nnz; + + double ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double total_initial_time = 0.0; + double total_transpose_uint8_time = 0.0; + double total_transpose_32xN_time = 0.0; + double total_compute_time = 0.0; + double total_transpose_Nx32_time = 0.0; + double total_run_time = 0.0; +#endif + double ops = double(NITER) * B_csc.NumOfNonZeros() * M * 2; + double bytes = double(NITER) * + (M * N * sizeof(int32_t) + M * K + + B_csc.NumOfNonZeros() * (sizeof(int16_t) + sizeof(int8_t)) + + B_csc.ColPtr().size() * sizeof(int32_t)); + + spmdm_ref(M, A.data(), K, B_csc, accumulation, C_ref.data(), N); + + chrono::time_point t_begin, t_end; + for (int iter = 0; iter < NWARMUP + NITER; ++iter) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + C[i * N + j] = i + j; + } + } + llc_flush(llc); + + t_begin = chrono::system_clock::now(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + spmdm_initial_time = 0.0; + spmdm_transpose_uint8_time = 0.0; + spmdm_transpose_32xN_time = 0.0; + spmdm_compute_time = 0.0; + spmdm_transpose_Nx32_time = 0.0; + spmdm_run_time = 0.0; +#endif + +#ifndef FBGEMM_MEASURE_TIME_BREAKDOWN +#pragma omp parallel +#endif + { +#if defined (FBGEMM_MEASURE_TIME_BREAKDOWN) || !defined(_OPENMP) + int num_threads = 1; + int tid = 0; +#else + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#endif + int i_per_thread = + ((M + 31) / 32 + num_threads - 1) / num_threads * 32; + int i_begin = std::min(tid * i_per_thread, M); + int i_end = std::min(i_begin + i_per_thread, M); + + block_type_t block = {i_begin, i_end - i_begin, 0, N}; + B_csc.SpMDM( + block, A.data(), K, accumulation, C.data() + i_begin * N, N); + } + t_end = chrono::system_clock::now(); + if (iter >= NWARMUP) { + double dt = chrono::duration(t_end - t_begin).count(); + // double dt = chrono::duration_cast(t_end - + // t_begin).count(); + ttot += dt; +#ifdef 
FBGEMM_MEASURE_TIME_BREAKDOWN + total_initial_time += spmdm_initial_time; + total_transpose_uint8_time += spmdm_transpose_uint8_time; + total_transpose_32xN_time += spmdm_transpose_32xN_time; + total_compute_time += spmdm_compute_time; + total_transpose_Nx32_time += spmdm_transpose_Nx32_time; + total_run_time += spmdm_run_time; +#endif + } + } + + compare_buffers(C_ref.data(), C.data(), M, N, N, 5, 0); + + cout << fixed << B_csc.Density() << ", " << accumulation << ", "; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << total_initial_time / (double)NITER / 1e6 << ", " + << total_transpose_uint8_time / (double)NITER / 1e6 << ", " + << total_transpose_32xN_time / (double)NITER / 1e6 << ", " + << total_compute_time / (double)NITER / 1e6 << ", " + << total_transpose_Nx32_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", "; +#endif + // Report performance + cout << fixed << bytes / ttot / 1e9 << ", " << ops / ttot / 1e9 << endl; + + } // accumulation + } // for each density + } // for each shape + + return 0; +} diff --git a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc new file mode 100644 index 0000000..62010ec --- /dev/null +++ b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc @@ -0,0 +1,241 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +void performance_test() { + vector shapes = { + // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 
), + conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ), + }; + + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + + chrono::time_point begin, end; + for (auto conv_p : shapes) { + aligned_vector Afp32( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f); + aligned_vector Aint8( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0); + + aligned_vector Aint8_out( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC, + 0); + + aligned_vector Bfp32( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f); + aligned_vector Bint8( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); + + aligned_vector Cint32_ref( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + + aligned_vector Cint32_fb( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + aligned_vector Cint32_fb2( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + cout << conv_p.toString() << endl; + + // A matrix (input activations) + randFill(Afp32, 0, 5); + int32_t Aint8_zero_point = 4; + for (auto i = 0; i < Afp32.size(); ++i) { + Aint8[i] = static_cast(Afp32[i]); + } + + // B matrix (weights) + randFill(Bfp32, -4, 4); + // int32_t Bint8_zero_point = -3; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bint8[i] = static_cast(Bfp32[i]); + } + + // reference implementation + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + // matrix dimensions after im2col + int MDim = conv_p.MB * conv_p.OH * conv_p.OW; + int NDim = conv_p.OC; + int KDim = conv_p.KH * conv_p.KW * conv_p.IC; + + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim, + // "B unpacked"); + // packedB.printPackedMatrix("B Packed"); + + double ttot = 0; + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithIm2Col::rowOffsetBufferSize()); + + PackAWithIm2Col packA( + conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data()); + + PackBMatrix packedB( + matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim); + + // no-op output process objects + DoNothing doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + + ttot = 0; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + fbgemmPacked( + packA, + packedB, + Cint32_fb.data(), + Cint32_fb.data(), + NDim, + memcopyObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); + } + } + cout << fixed << "fused im2col GOPs: " + << static_cast(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + + compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, 
NDim, 5); + + ttot = 0; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data()); + + // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim, + // KDim, "A_out after im2col unpacked"); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + MDim, + KDim, + Aint8_out.data(), + KDim, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + fbgemmPacked( + packAN, + packedB, + Cint32_fb2.data(), + Cint32_fb2.data(), + NDim, + memcopyObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); + } + } + ((volatile char*)(llc.data())); + + // packedB.printPackedMatrix("bench B Packed"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_fb.data(), MDim, NDim, NDim, + // "C fb fp32"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_fb2.data(), + // MDim, NDim, NDim, "C fb2 fp32"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32"); + + cout << fixed << "unfused im2col GOPs: " + << static_cast(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + // cout << "total time: " << ttot << " ns" << endl; + compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5); + } // shapes +} + +int main() { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff --git a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc new file mode 100644 index 0000000..9adea49 --- /dev/null +++ b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc @@ -0,0 +1,242 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +void performance_test() { + vector shapes = { + // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ), + }; + + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + + chrono::time_point begin, end; + for (auto conv_p : shapes) { + aligned_vector Afp32( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f); + aligned_vector Aint8( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0); + + aligned_vector Aint8_out( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC, + 0); + + aligned_vector Bfp32( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f); + aligned_vector Bint8( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); + + 
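+    // Illustrative note (not part of the original benchmark): with the
+    // conv_param_t constructor in include/fbgemm/ConvUtils.h, the output
+    // spatial size is OH = (IH + 2*pad_h - KH)/stride_h + 1 (and likewise
+    // for OW), so the GEMM dimensions computed below after im2col are
+    //   MDim = MB*OH*OW, NDim = OC, KDim = KH*KW*IC.
+    // E.g. for conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1) above:
+    //   OH = OW = (14 + 2 - 3)/1 + 1 = 14, so MDim = 196, NDim = 32, KDim = 288. +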
aligned_vector Cint32_ref( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + + aligned_vector Cint32_fb( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + aligned_vector Cint32_fb2( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + cout << conv_p.toString() << endl; + + // A matrix (input activations) + randFill(Afp32, 0, 5); + int32_t Aint8_zero_point = 4; + for (auto i = 0; i < Afp32.size(); ++i) { + Aint8[i] = static_cast(Afp32[i]); + } + + // B matrix (weights) + randFill(Bfp32, -4, 4); + // int32_t Bint8_zero_point = -3; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bint8[i] = static_cast(Bfp32[i]); + } + + // reference implementation + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + // matrix dimensions after im2col + int MDim = conv_p.MB * conv_p.OH * conv_p.OW; + int NDim = conv_p.OC; + int KDim = conv_p.KH * conv_p.KW * conv_p.IC; + + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim, + // "B unpacked"); + // packedB.printPackedMatrix("B Packed"); + + double ttot = 0; + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithIm2Col::rowOffsetBufferSize()); + + PackAWithIm2Col packA( + conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data()); + + PackBMatrix packedB( + matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim); + + // no-op output process objects + DoNothing doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + + ttot = 0; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + fbgemmPacked( + packA, + packedB, + Cint32_fb.data(), + Cint32_fb.data(), + NDim, + memcopyObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); + } + } + cout << fixed << "fused im2col GOPs: " + << static_cast(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + + compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5); + + ttot = 0; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data()); + + // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim, + // KDim, "A_out after im2col unpacked"); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + MDim, + KDim, + Aint8_out.data(), + KDim, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + fbgemmPacked( + packAN, + packedB, + Cint32_fb2.data(), + Cint32_fb2.data(), + NDim, + memcopyObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); + } + } + + ((volatile char*)(llc.data())); + + // packedB.printPackedMatrix("bench B Packed"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_fb.data(), MDim, NDim, NDim, + // "C fb fp32"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_fb2.data(), + // MDim, NDim, NDim, "C fb2 fp32"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32"); + + cout << fixed << "unfused im2col GOPs: " + << static_cast(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + // cout << "total time: " << ttot << " ns" << endl; + compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5); + } // shapes +} + +int main() { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff 
--git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc new file mode 100644 index 0000000..4a2eda4 --- /dev/null +++ b/bench/PackedFloatInOutBenchmark.cc @@ -0,0 +1,269 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#ifdef USE_MKL +#include +#endif + +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "test/QuantizationHelpers.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +void performance_test() { + vector> shapes = { + {1, 128, 512}, + {1, 1024, 256}, + {1, 2048, 512}, + {1, 4096, 1024}, + + {6, 256, 1024}, + {6, 256, 2048}, + {6, 512, 512}, + {6, 1024, 256}, + {6, 2048, 256}, + {6, 2048, 512}, + {6, 4096, 256}, + {6, 4096, 1024}, + {6, 4096, 2048}, + + {10, 2048, 256}, + {10, 4096, 1024}, + + {20, 2048, 256}, + {20, 4096, 1024}, + + {102, 1024, 512}, + {102, 2323, 256}, + {102, 512, 256}, + + {1, 800, 3200}, + {1, 800, 8000}, + + {16, 256, 1500}, + {16, 256, 1567}, + {1, 128, 2876}, + {16, 128, 1567}, + {1, 128, 2722}, + {16, 256, 512}, + }; + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + + cout << setw(7) << "M, " << setw(7) << "N, " << setw(7) << "K, " << setw(18) + << "Type, " << setw(5) << "GOPS" << endl; + + chrono::time_point start, end; + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + float alpha = 1.f, beta = 0.f; + aligned_vector Afp32(m * k, 0.0f); + aligned_vector Aint8(m * k, 0); + + aligned_vector Bfp32(k * n, 0.0f); + aligned_vector Bint8(k * n, 0); + + aligned_vector Cfp32_mkl(m * n, 0.0f); + aligned_vector Cfp32_fb(m * n, 0.0f); + + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + + // A matrix + randFill(Aint8, 0, 255); + float Aint8_scale = 0.11; + int32_t Aint8_zero_point = 43; + for (auto i = 0; i < Afp32.size(); ++i) { + Afp32[i] = Aint8_scale * (Aint8[i] - Aint8_zero_point); + } + + randFill(Bint8, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bint8.data()); + + float Bint8_scale = 0.49; + int32_t Bint8_zero_point = -30; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bfp32[i] = Bint8_scale * (Bint8[i] - Bint8_zero_point); + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n); + col_offsets_with_zero_pt_s8acc32_ref( + k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data()); + + double ttot = 0; + std::string type; + double nops = 2.0 * (double)m * (double)n * (double)k * (double)NITER; +#ifdef USE_MKL + type = "MKL_FP32"; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + start = chrono::high_resolution_clock::now(); + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + Afp32.data(), + k, + Bfp32.data(), + n, + beta, + Cfp32_mkl.data(), + n); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - start); + ttot += dur.count(); + } + } + ((volatile char*)(llc.data())); + cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k << ", " + << setw(16) << type << ", " << setw(5) << fixed << setw(5) + << setprecision(1) << nops / ttot << endl; +#endif + + int32_t C_multiplier = 16544; 
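+    // Illustrative sketch (not part of the original benchmark): since the
+    // inputs were constructed above as
+    //   Afp32[i] = Aint8_scale * (Aint8[i] - Aint8_zero_point) and
+    //   Bfp32[i] = Bint8_scale * (Bint8[i] - Bint8_zero_point),
+    // the float output recovered by the ReQuantizeForFloat pipeline below is
+    //   Cfp32[i][j] ~= Aint8_scale * Bint8_scale *
+    //                  sum_k (Aint8[i][k] - A_zp) * (Bint8[k][j] - B_zp),
+    // which is why the final compare_buffers() against the MKL fp32 result
+    // only needs the quantization tolerance atol.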
+ int32_t C_right_shift = 35; + int32_t C_zero_pt = 5; + + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B + // unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Cfp32_mkl.data(), + // m, n, n, "C mkl fp32"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint8_local.data(), m, n, n, "C requantized"); + // printMatrix(matrix_op_t::NoTranspose, col_offsets.data(), 1, n, n, "col + // offsets before"); + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithQuantRowOffset::rowOffsetBufferSize()); + + PackAWithQuantRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k, + Afp32.data(), + k, + nullptr, /*buffer for packed matrix*/ + Aint8_scale, + Aint8_zero_point, + 1, /*groups*/ + row_offset_buf.data()); + + PackBMatrix packedBN( + matrix_op_t::NoTranspose, + k, + n, + Bint8.data(), + n, + nullptr, + 1, + Bint8_zero_point); + + DoNothing doNothingObj{}; + ReQuantizeForFloat outputProcObj( + doNothingObj, + Aint8_scale, + Bint8_scale, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + ttot = 0; + type = "FBGEMM_i8_acc32"; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + start = chrono::high_resolution_clock::now(); + fbgemmPacked( + packAN, + packedBN, + Cfp32_fb.data(), + (int32_t*)Cfp32_fb.data(), + n, + outputProcObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - start); + ttot += dur.count(); + } + } + ((volatile char*)(llc.data())); + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B + // unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(), + // m, n, n, "C requantized after"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint8_fb.data(), m, n, n, "C fb"); + // printMatrix(matrix_op_t::NoTranspose, + // col_offsets.data(), 1, n, n, "col offsets after"); + // compare_buffers(row_offsets.data(), row_offset_buf.data(), + // row_offsets.size(), 5); + // printMatrix(matrix_op_t::NoTranspose, Cfp32_fb.data(), + // m, n, n, "C fb fp32"); + cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k << ", " + << setw(16) << type << ", " << setw(5) << fixed << setw(5) + << setprecision(1) << nops / ttot << endl; + cout << endl; + // cout << "total time: " << ttot << " ns" << endl; + + float maximum = *max_element(Cfp32_mkl.begin(), Cfp32_mkl.end()); + float minimum = *min_element(Cfp32_mkl.begin(), Cfp32_mkl.end()); + float atol = (maximum - minimum) / 255 / 1.9; + +#ifdef USE_MKL + // correctness check + compare_buffers(Cfp32_mkl.data(), Cfp32_fb.data(), m, n, n, 5, atol); +#endif + } +} + +int main(int /* unused */, char** /* unused */) { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc new file mode 100644 index 0000000..a758f55 --- /dev/null +++ b/bench/PackedRequantizeAcc16Benchmark.cc @@ -0,0 +1,439 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#ifdef USE_MKL +#include +#endif + +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +enum class BenchmarkType { + BARE_BONE, // no row-offset in input packing, and no output processing + REQUANTIZATION, // no row-offset in input packing, and requantization + ROW_OFFSET_AND_REQUANTIZATION, // row-offset in input packing, and + // requantization + EVERYTHING, // row-offset in input packing, and requantization + spmdm +}; + +void performance_test() { + vector> shapes = { + // m, n, k + {64, 68, 17}, {60, 128, 64}, {25088, 256, 64}, + {25088, 64, 64}, {25088, 64, 576}, {25088, 64, 256}, + + {6272, 512, 256}, {6272, 128, 256}, {6272, 128, 1152}, + {6272, 512, 128}, {6272, 128, 512}, + + {1568, 1024, 512}, {1568, 256, 512}, {1568, 256, 2304}, + {1568, 1024, 256}, {1568, 256, 1024}, + + {392, 2048, 1024}, {392, 512, 1024}, {392, 512, 4608}, + {392, 2048, 512}, {392, 512, 2048}, + }; + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." + << endl; + cout << "M, " + << "N, " + << "K, " + << "Output Processing, " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "Total (ms), " + << "GOPS" << endl; +#else + cout << setw(7) << "M, " << setw(7) << "N, " << setw(7) << "K, " << setw(32) + << "Output Processing, " << setw(18) << "Type, " << setw(5) << "GOPS" + << endl; +#endif + + chrono::time_point begin, end; + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + float alpha = 1.f, beta = 0.f; + aligned_vector Afp32(m * k, 0.0f); + aligned_vector Aint8(m * k, 0); + + aligned_vector Bfp32(k * n, 0.0f); + aligned_vector Bint8(k * n, 0); + + aligned_vector Cfp32_mkl(m * n, 0.0f); + // just used for result comparisons + aligned_vector Cint32_mkl(m * n, 0.0f); + // requantize results + aligned_vector Cint8_mkl(m * n, 0.0f); + aligned_vector Cint32_fb(m * n, 0.0f); + aligned_vector Cint8_fb(m * n, 0.0f); + + // A matrix + randFill(Afp32, 0, 50); + int32_t Aint8_zero_point = 43; + for (auto i = 0; i < Afp32.size(); ++i) { + Aint8[i] = static_cast(Afp32[i]); + } + + randFill(Bfp32, -8, 8); + + double nops = 2.0 * static_cast(NITER) * m * n * k; + double ttot = 0.0; + string runType; + +#ifdef USE_MKL + ttot = 0.0; + runType = "MKL_fp32"; + cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k + << ", "; + cout << setw(30) << "NA"; + cout << ", "; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + Afp32.data(), + k, + Bfp32.data(), + n, + beta, + Cfp32_mkl.data(), + n); + end = chrono::high_resolution_clock::now(); + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); + } + } + ((volatile char*)(llc.data())); + cout << setw(16) << runType << ", " << fixed << setw(5) << setprecision(1) + << nops / ttot << endl; + + for (auto i = 0; i < Cfp32_mkl.size(); ++i) { + Cint32_mkl[i] = static_cast(Cfp32_mkl[i]); + } +#endif + + for (BenchmarkType bench_type : + {BenchmarkType::BARE_BONE, + BenchmarkType::REQUANTIZATION, + 
BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION, + BenchmarkType::EVERYTHING}) { + // When we don't compute row_offset in fbgemm, we set B_zero_point to 0 + // to get the same result as the reference. + int32_t Bint8_zero_point = (bench_type == BenchmarkType::BARE_BONE || + bench_type == BenchmarkType::REQUANTIZATION) + ? 0 + : -30; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bint8[i] = static_cast(Bfp32[i]); + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n); + col_offsets_with_zero_pt_s8acc32_ref( + k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data()); + + vector row_offsets; + row_offsets.resize(m); + + row_offsets_u8acc32_ref(m, k, k, Aint8.data(), row_offsets.data()); + + float C_multiplier = + (bench_type == BenchmarkType::BARE_BONE) ? 1 : 0.1234; + int32_t C_zero_pt = (bench_type == BenchmarkType::BARE_BONE) ? 0 : 5; + + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, + // "B unpacked"); + // packedB.printPackedMatrix("B Packed"); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif + + cout << setw(5) << m << ", " << setw(5) << n << ", " << setw(5) << k + << ", "; + switch (bench_type) { + case BenchmarkType::BARE_BONE: + cout << setw(30) << "bare_bone"; + break; + case BenchmarkType::REQUANTIZATION: + cout << setw(30) << "requantization"; + break; + case BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION: + cout << setw(30) << "row_offset_and_requantization"; + break; + case BenchmarkType::EVERYTHING: + cout << setw(30) << "everything"; + break; + }; + cout << ", "; + + requantize_u8acc32_ref( + m, + n, + n, + Cint32_mkl.data(), + Cint8_mkl.data(), + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + row_offsets.data(), + col_offsets.data(), + nullptr); // bias + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithRowOffset::rowOffsetBufferSize()); + + PackAMatrix packA( + matrix_op_t::NoTranspose, + m, + k, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point); + PackAWithRowOffset packAWithRowOffset( + matrix_op_t::NoTranspose, + m, + k, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + PackBMatrix packedB( + matrix_op_t::NoTranspose, k, n, Bint8.data(), n); + + // no-op output process objects + DoNothing doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + + // spmdm -> requantization -> nothing + // construct an output processing pipeline in reverse order + // i.e. last output operation first + // Last operation should always be DoNothing with + // correct input and output type. + DoNothing<> doNothingObj{}; + // Requantization back to int8 + ReQuantizeOutput reqObj( + doNothingObj, + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + bench_type == BenchmarkType::REQUANTIZATION + ? 
nullptr + : packAWithRowOffset.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + CompressedSparseColumn B_csc(k, n); + + float density = 0.001f; + + // deterministic random number + default_random_engine eng; + binomial_distribution<> per_col_nnz_dist(k, density); + + vector row_indices(k); + + int total_nnz = 0; + for (int j = 0; j < n; ++j) { + B_csc.ColPtr()[j] = total_nnz; + + int nnz_of_j = per_col_nnz_dist(eng); + total_nnz += nnz_of_j; + + iota(row_indices.begin(), row_indices.end(), 0); + shuffle(row_indices.begin(), row_indices.end(), eng); + sort(row_indices.begin(), row_indices.begin() + nnz_of_j); + + for (int kidx = 0; kidx < nnz_of_j; ++kidx) { + B_csc.RowIdx().push_back(row_indices[kidx]); + // put the current B value + B_csc.Values().push_back(Bint8[row_indices[kidx] * n + j]); + // make current B value zero + Bint8[row_indices[kidx] * n + j] = 0; + // std::cout << "(" << row_indices[kidx] << ", " << j << ")" << + // endl; + } + } + B_csc.ColPtr()[n] = total_nnz; + + // the top most (first) operation in the output processing + // pipeline is spmdm + // outType = final output type after fullly processing through pipeline + // inType = initial input type at the first call to the whole pipeline + DoSpmdmOnInpBuffer< + ReQuantizeOutput::outType, + int32_t, + ReQuantizeOutput> + spmdmObj(reqObj, Aint8.data(), k, B_csc); + + // printMatrix(matrix_op_t::NoTranspose, + // Cint32_mkl.data(), m, n, n, "C mkl"); + ttot = 0; + runType = "FBGEMM_i8_acc16"; + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + switch (bench_type) { + case BenchmarkType::BARE_BONE: + fbgemmPacked( + packA, + packedB, + Cint32_fb.data(), + Cint32_fb.data(), + n, + memcopyObj, + 0, + 1); + break; + case BenchmarkType::REQUANTIZATION: + fbgemmPacked( + packA, + packedB, + Cint8_fb.data(), + Cint32_fb.data(), + n, + reqObj, + 0, + 1); + break; + case BenchmarkType::ROW_OFFSET_AND_REQUANTIZATION: + fbgemmPacked( + packAWithRowOffset, + packedB, + Cint8_fb.data(), + Cint32_fb.data(), + n, + reqObj, + 0, + 1); + break; + case BenchmarkType::EVERYTHING: + fbgemmPacked( + packAWithRowOffset, + packedB, + Cint8_fb.data(), + Cint32_fb.data(), + n, + spmdmObj, + 0, + 1); + break; + }; + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - begin); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + + ((volatile char*)(llc.data())); + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B + // unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(), + // m, n, n, "C requantized after"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint8_fb.data(), m, n, n, "C fb"); + // printMatrix(matrix_op_t::NoTranspose, + // col_offsets.data(), 1, n, n, "col offsets after"); + // compare_buffers(row_offsets.data(), row_offset_buf.data(), + // row_offsets.size(), 5); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 
1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", "; +#endif + cout << setw(16) << runType << ", " << fixed << setw(5) << setprecision(1) + << nops / ttot << endl; + +#ifdef USE_MKL + if (bench_type == BenchmarkType::BARE_BONE) { + compare_buffers(Cint32_mkl.data(), Cint32_fb.data(), m, n, n, 5); + } else { + compare_buffers(Cint8_mkl.data(), Cint8_fb.data(), m, n, n, 5); + } +#endif + } // test_outlier + cout << endl; + } // shapes +} + +int main() { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc new file mode 100644 index 0000000..27f1433 --- /dev/null +++ b/bench/PackedRequantizeAcc32Benchmark.cc @@ -0,0 +1,329 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#ifdef USE_MKL +#include +#endif + +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "test/QuantizationHelpers.h" +#include "BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +void performance_test() { + vector> shapes = { + {156800, 4, 36}, + {156800, 8, 36}, + {156800, 16, 36}, + {1, 128, 512}, + {1, 1024, 256}, + {1, 2048, 512}, + {1, 4096, 1024}, + + {6, 256, 1024}, + {6, 256, 2048}, + {6, 512, 512}, + {6, 1024, 256}, + {6, 2048, 256}, + {6, 2048, 512}, + {6, 4096, 256}, + {6, 4096, 1024}, + {6, 4096, 2048}, + + {10, 2048, 256}, + {10, 4096, 1024}, + + {20, 2048, 256}, + {20, 4096, 1024}, + + {102, 1024, 512}, + {102, 2323, 256}, + {102, 512, 256}, + + {1, 800, 3200}, + {1, 800, 8000}, + + {16, 256, 1500}, + {16, 256, 1567}, + {1, 128, 2876}, + {16, 128, 1567}, + {1, 128, 2722}, + {16, 256, 512}, + }; + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
+ << endl; + cout << "M, " + << "N, " + << "K, " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "Total (ms), " + << "GOPs" << endl; +#else + cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, " << setw(18) + << "Type, " << setw(5) << "GOPS" << endl; +#endif + + chrono::time_point start, end; + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + float alpha = 1.f, beta = 0.f; + aligned_vector Afp32(m * k, 0.0f); + aligned_vector Aint8(m * k, 0); + + aligned_vector Bfp32(k * n, 0.0f); + aligned_vector Bint8(k * n, 0); + + aligned_vector Cfp32_mkl(m * n, 0.0f); + aligned_vector Cint32_mkl(m * n, 0.0f); + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint32_local(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + aligned_vector Cint8_local(m * n, 0); + + // A matrix + randFill(Aint8, 0, 255); + // float Aint8_scale = 0.11; + int32_t Aint8_zero_point = 43; + for (auto i = 0; i < Afp32.size(); ++i) { + Afp32[i] = (float)Aint8[i]; + } + + randFill(Bint8, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bint8.data()); + + // float Bint8_scale = 0.49; + int32_t Bint8_zero_point = -30; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bfp32[i] = (float)Bint8[i]; + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n); + col_offsets_with_zero_pt_s8acc32_ref( + k, n, n, Bint8.data(), Bint8_zero_point, col_offsets.data()); + + double nops = 2.0 * static_cast(NITER) * m * n * k; + double ttot = 0.0; + string runType; +#ifdef USE_MKL + runType = "MKL_fp32"; + for (auto i = 0; i < NWARMUP + NITER; ++i) { + llc_flush(llc); + start = chrono::high_resolution_clock::now(); + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + Afp32.data(), + k, + Bfp32.data(), + n, + beta, + Cfp32_mkl.data(), + n); + end = chrono::high_resolution_clock::now(); + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - start); + ttot += dur.count(); + } + } + ((volatile char*)(llc.data())); + + cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k << ", " + << setw(16) << runType << ", " << setw(5) << fixed << setw(5) + << setprecision(1) << nops / ttot << endl; + + for (auto i = 0; i < Cfp32_mkl.size(); ++i) { + Cint32_mkl[i] = (int32_t)Cfp32_mkl[i]; + } +#endif + + vector row_offsets; + row_offsets.resize(m); + + float C_multiplier = 0.1234; + int32_t C_zero_pt = 5; + + matmul_u8i8acc32_ref( + m, n, k, k, n, n, Aint8.data(), Bint8.data(), Cint32_local.data()); + + row_offsets_u8acc32_ref(m, k, k, Aint8.data(), row_offsets.data()); + + requantize_u8acc32_ref( + m, + n, + n, + Cint32_local.data(), + Cint8_local.data(), + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + row_offsets.data(), + col_offsets.data(), + nullptr); // bias + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B + // unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_local.data(), + // m, n, n, "C int32"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint8_local.data(), m, n, n, "C requantized"); + // printMatrix(matrix_op_t::NoTranspose, col_offsets.data(), 1, n, n, "col + // offsets before"); + + vector row_offset_buf; + row_offset_buf.resize(PackAWithRowOffset::rowOffsetBufferSize()); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + 
row_offset_buf.data()); + + PackBMatrix packedBN( + matrix_op_t::NoTranspose, + k, + n, + Bint8.data(), + n, + nullptr, + 1, + Bint8_zero_point); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput outputProcObj( + doNothingObj, + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + ttot = 0.0; + runType = "FBGEMM_i8_acc32"; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + start = chrono::high_resolution_clock::now(); + fbgemmPacked( + packAN, + packedBN, + Cint8_fb.data(), + Cint32_buffer.data(), + n, + outputProcObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - start); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + ((volatile char*)(llc.data())); + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), k, n, n, "B + // unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Aint8.data(), m, k, k, + // "A unpacked"); + // printMatrix(matrix_op_t::NoTranspose, Cint8_local.data(), + // m, n, n, "C requantized after"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint8_fb.data(), m, n, n, "C fb"); + // printMatrix(matrix_op_t::NoTranspose, + // col_offsets.data(), 1, n, n, "col offsets after"); + // compare_buffers(row_offsets.data(), row_offset_buf.data(), + // row_offsets.size(), 5); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", "; +#endif + cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k << ", " + << setw(16) << runType << ", " << setw(5) << fixed << setw(5) + << setprecision(1) << nops / ttot << endl; + cout << endl; + +#ifdef USE_MKL + compare_buffers(Cint8_local.data(), Cint8_fb.data(), m, n, n, 5); +#endif + } +} + +int main(int /* unused */, char** /* unused */) { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff --git a/cmake/modules/DownloadASMJIT.cmake b/cmake/modules/DownloadASMJIT.cmake new file mode 100644 index 0000000..6e6a6f9 --- /dev/null +++ b/cmake/modules/DownloadASMJIT.cmake @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.7 FATAL_ERROR) + +project(asmjit-download NONE) + +include(ExternalProject) + +ExternalProject_Add(asmjit + GIT_REPOSITORY https://github.com/asmjit/asmjit + GIT_TAG master + SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/asmjit" + BINARY_DIR "${FBGEMM_BINARY_DIR}/asmjit" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/modules/DownloadCPUINFO.cmake b/cmake/modules/DownloadCPUINFO.cmake new file mode 100644 index 0000000..730ecbf --- /dev/null +++ b/cmake/modules/DownloadCPUINFO.cmake @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.7 
FATAL_ERROR) + +project(cpuinfo-download NONE) + +include(ExternalProject) + +ExternalProject_Add(cpuinfo + GIT_REPOSITORY https://github.com/pytorch/cpuinfo + GIT_TAG master + SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" + BINARY_DIR "${FBGEMM_BINARY_DIR}/cpuinfo" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/modules/DownloadGTEST.cmake b/cmake/modules/DownloadGTEST.cmake new file mode 100644 index 0000000..b0a56f8 --- /dev/null +++ b/cmake/modules/DownloadGTEST.cmake @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.7 FATAL_ERROR) + +project(googletest-download NONE) + +include(ExternalProject) + +ExternalProject_Add(googletest + GIT_REPOSITORY https://github.com/google/googletest + GIT_TAG 0fc5466dbb9e623029b1ada539717d10bd45e99e + SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/googletest" + BINARY_DIR "${FBGEMM_BINARY_DIR}/googletest" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/modules/FindMKL.cmake b/cmake/modules/FindMKL.cmake new file mode 100644 index 0000000..6b38af7 --- /dev/null +++ b/cmake/modules/FindMKL.cmake @@ -0,0 +1,270 @@ +# - Find INTEL MKL library +# This module finds the Intel Mkl libraries. +# Note: This file is a modified version of pytorch/cmake/Modules/FindMKL.cmake +# +# This module sets the following variables: +# MKL_FOUND - set to true if a library implementing the CBLAS interface is found +# MKL_VERSION - best guess +# MKL_INCLUDE_DIR - path to include dir. +# MKL_LIBRARIES - list of libraries for base mkl + +# Do nothing if MKL_FOUND was set before! +IF (NOT MKL_FOUND) + +SET(MKL_VERSION) +SET(MKL_INCLUDE_DIR) +SET(MKL_LIBRARIES) + +# Includes +INCLUDE(CheckTypeSize) +INCLUDE(CheckFunctionExists) + +# Intel Compiler Suite +SET(INTEL_COMPILER_DIR "/opt/intel" CACHE STRING + "Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)") +SET(INTEL_MKL_DIR "/opt/intel/mkl" CACHE STRING + "Root directory of the Intel MKL (standalone)") +SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL + "Force using the sequential (non threaded) libraries") + +# Checks +CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP) +IF ("${SIZE_OF_VOIDP}" EQUAL 8) + SET(mklvers "intel64") + SET(iccvers "intel64") + SET(mkl64s "_lp64") +ELSE ("${SIZE_OF_VOIDP}" EQUAL 8) + SET(mklvers "32") + SET(iccvers "ia32") + SET(mkl64s) +ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8) +IF(CMAKE_COMPILER_IS_GNUCC) + SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread") + SET(mklifaces "intel" "gf") + SET(mklrtls "gomp" "iomp5") +ELSE(CMAKE_COMPILER_IS_GNUCC) + SET(mklthreads "mkl_intel_thread") + SET(mklifaces "intel") + SET(mklrtls "iomp5" "guide") + IF (MSVC) + SET(mklrtls "libiomp5md") + ENDIF (MSVC) +ENDIF (CMAKE_COMPILER_IS_GNUCC) + +# Kernel libraries dynamically loaded +SET(mklseq) + + +# Paths +SET(saved_CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}) +SET(saved_CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}) +IF(WIN32) + # Set default MKLRoot for Windows + IF($ENV{MKLProductDir}) + SET(INTEL_COMPILER_DIR $ENV{MKLProductDir}) + ELSE() + SET(INTEL_COMPILER_DIR + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows") + ENDIF() + # Change mklvers and iccvers when we are using MSVC instead of ICC + IF(MSVC AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + SET(mklvers "${mklvers}_win") + SET(iccvers "${iccvers}_win") + ENDIF() +ENDIF(WIN32) +IF (EXISTS ${INTEL_COMPILER_DIR}) + # TODO: diagnostic if dir does not exist + SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_COMPILER_DIR}/lib/${iccvers}") + IF(MSVC) + 
SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_COMPILER_DIR}/compiler/lib/${iccvers}") + ENDIF() + IF (NOT EXISTS ${INTEL_MKL_DIR}) + SET(INTEL_MKL_DIR "${INTEL_COMPILER_DIR}/mkl") + ENDIF() +ENDIF() +IF (EXISTS ${INTEL_MKL_DIR}) + # TODO: diagnostic if dir does not exist + SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} + "${INTEL_MKL_DIR}/include") + SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_MKL_DIR}/lib/${mklvers}") + IF (MSVC) + SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_MKL_DIR}/lib/${iccvers}") + ENDIF() +ENDIF() + +# Try linking multiple libs +MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags) + # This macro checks for the existence of the combination of libraries given by _list. + # If the combination is found, this macro checks whether we can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + SET(_prefix "${LIBRARIES}") + # start checking + SET(_libraries_work TRUE) + SET(${LIBRARIES}) + SET(_combined_name) + SET(_paths) + SET(__list) + FOREACH(_elem ${_list}) + IF(__list) + SET(__list "${__list} - ${_elem}") + ELSE(__list) + SET(__list "${_elem}") + ENDIF(__list) + ENDFOREACH(_elem) + MESSAGE(STATUS "Checking for [${__list}]") + FOREACH(_library ${_list}) + SET(_combined_name ${_combined_name}_${_library}) + IF(_libraries_work) + IF(${_library} STREQUAL "gomp") + FIND_PACKAGE(OpenMP) + IF(OPENMP_FOUND) + SET(${_prefix}_${_library}_LIBRARY ${OpenMP_C_FLAGS}) + ENDIF(OPENMP_FOUND) + ELSE(${_library} STREQUAL "gomp") + FIND_LIBRARY(${_prefix}_${_library}_LIBRARY NAMES ${_library}) + ENDIF(${_library} STREQUAL "gomp") + MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY) + SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + SET(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + IF(${_prefix}_${_library}_LIBRARY) + MESSAGE(STATUS " Library ${_library}: ${${_prefix}_${_library}_LIBRARY}") + ELSE(${_prefix}_${_library}_LIBRARY) + MESSAGE(STATUS " Library ${_library}: not found") + ENDIF(${_prefix}_${_library}_LIBRARY) + ENDIF(_libraries_work) + ENDFOREACH(_library ${_list}) + # Test this combination of libraries. 
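+  # (Illustrative note, not part of the original module) For MKL this amounts
+  # to link-testing the cblas_sgemm symbol against a candidate list such as
+  #   "mkl_intel_lp64;mkl_gnu_thread;mkl_core;gomp;pthread;m;dl"
+  # as produced by the invocations of this macro further below.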
+ IF(_libraries_work) + SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}}) + SET(CMAKE_REQUIRED_LIBRARIES "${CMAKE_REQUIRED_LIBRARIES};${CMAKE_REQUIRED_LIBRARIES}") + CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS) + SET(CMAKE_REQUIRED_LIBRARIES) + MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS) + SET(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + ENDIF(_libraries_work) + # Fin + IF(_libraries_work) + ELSE (_libraries_work) + SET(${LIBRARIES}) + MARK_AS_ADVANCED(${LIBRARIES}) + ENDIF(_libraries_work) +ENDMACRO(CHECK_ALL_LIBRARIES) + +IF(WIN32) + SET(mkl_m "") + SET(mkl_pthread "") +ELSE(WIN32) + SET(mkl_m "m") + SET(mkl_pthread "pthread") +ENDIF(WIN32) + +IF(UNIX AND NOT APPLE) + SET(mkl_dl "${CMAKE_DL_LIBS}") +ELSE(UNIX AND NOT APPLE) + SET(mkl_dl "") +ENDIF(UNIX AND NOT APPLE) + +# Check for version 10/11 +IF (NOT MKL_LIBRARIES) + SET(MKL_VERSION 1011) +ENDIF (NOT MKL_LIBRARIES) +FOREACH(mklrtl ${mklrtls} "") + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + FOREACH(mklthread ${mklthreads}) + IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};${mkl_pthread};${mkl_m};${mkl_dl}" "") + ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL) + ENDFOREACH(mklthread) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) +FOREACH(mklrtl ${mklrtls} "") + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + IF (NOT MKL_LIBRARIES) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};mkl_sequential;mkl_core;${mkl_m};${mkl_dl}" "") + IF (MKL_LIBRARIES) + SET(mklseq "_sequential") + ENDIF (MKL_LIBRARIES) + ENDIF (NOT MKL_LIBRARIES) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) +FOREACH(mklrtl ${mklrtls} "") + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + FOREACH(mklthread ${mklthreads}) + IF (NOT MKL_LIBRARIES) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "") + ENDIF (NOT MKL_LIBRARIES) + ENDFOREACH(mklthread) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) + +# Check for older versions +IF (NOT MKL_LIBRARIES) + SET(MKL_VERSION 900) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl;guide;pthread;m" "") +ENDIF (NOT MKL_LIBRARIES) + +# Include files +IF (MKL_LIBRARIES) + FIND_PATH(MKL_INCLUDE_DIR "mkl_cblas.h") + MARK_AS_ADVANCED(MKL_INCLUDE_DIR) +ENDIF (MKL_LIBRARIES) + +# LibIRC: intel compiler always links this; +# gcc does not; but mkl kernels sometimes need it. +IF (MKL_LIBRARIES) + IF (CMAKE_COMPILER_IS_GNUCC) + FIND_LIBRARY(MKL_KERNEL_libirc "irc") + ELSEIF (CMAKE_C_COMPILER_ID AND NOT CMAKE_C_COMPILER_ID STREQUAL "Intel") + FIND_LIBRARY(MKL_KERNEL_libirc "irc") + ENDIF (CMAKE_COMPILER_IS_GNUCC) + MARK_AS_ADVANCED(MKL_KERNEL_libirc) + IF (MKL_KERNEL_libirc) + SET(MKL_LIBRARIES ${MKL_LIBRARIES} ${MKL_KERNEL_libirc}) + ENDIF (MKL_KERNEL_libirc) +ENDIF (MKL_LIBRARIES) + +# Final +SET(CMAKE_LIBRARY_PATH ${saved_CMAKE_LIBRARY_PATH}) +SET(CMAKE_INCLUDE_PATH ${saved_CMAKE_INCLUDE_PATH}) +IF (MKL_LIBRARIES AND MKL_INCLUDE_DIR) + SET(MKL_FOUND TRUE) + set(MKL_cmake_included true) +ELSE (MKL_LIBRARIES AND MKL_INCLUDE_DIR) + SET(MKL_FOUND FALSE) + SET(MKL_VERSION) +ENDIF (MKL_LIBRARIES AND MKL_INCLUDE_DIR) + +# Standard termination +IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) + MESSAGE(FATAL_ERROR "MKL library not found. 
Please specify library location") +ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) +IF(NOT MKL_FIND_QUIETLY) + IF(MKL_FOUND) + MESSAGE(STATUS "MKL library found") + ELSE(MKL_FOUND) + MESSAGE(STATUS "MKL library not found") + return() + ENDIF(MKL_FOUND) +ENDIF(NOT MKL_FIND_QUIETLY) + +# Do nothing if MKL_FOUND was set before! +ENDIF (NOT MKL_FOUND) diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h new file mode 100644 index 0000000..02e862f --- /dev/null +++ b/include/fbgemm/ConvUtils.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include + +namespace fbgemm2 { + +/** + * @brief A struct to conveniently store all convolution parameters. + */ +struct conv_param_t { + int MB; ///< Mini Batch size + int IC; ///< Number of Input Channels + int OC; ///< Number of Output Channels + int IH; ///< Input Image Height + int IW; ///< Input Image Width + int G; ///< Number of Groups + int KH; ///< Filter (Kernel) Height + int KW; ///< Filter (Kernel) Width + int stride_h; ///< Stride in Height Dimension + int stride_w; ///< Stride in Width Dimension + int pad_h; ///< Padding in Height Dimension (top and bottom) + int pad_w; ///< Padding in Width Dimension (left and right) + int dilation_h; ///< Kernel dilation in Height Dimension + int dilation_w; ///< Kernel dilation in Width Dimension + + // The following are derived parameters + int OH; ///< Output Image Height + int OW; ///< Output Image Width + int IHP; ///< Input Height Padded + int IWP; ///< Input Width Padded + + /** + * @brief Constructor for initializing the convolution parameters. + * TODO: Dilation is not handled correctly. + */ + conv_param_t( + int mb, + int ic, + int oc, + int ih, + int iw, + int g = 1, + int kh = 3, + int kw = 3, + int strd_h = 1, + int strd_w = 1, + int pd_h = 1, + int pd_w = 1) + : MB(mb), + IC(ic), + OC(oc), + IH(ih), + IW(iw), + G(g), + KH(kh), + KW(kw), + stride_h(strd_h), + stride_w(strd_w), + pad_h(pd_h), + pad_w(pd_w), + dilation_h(1), + dilation_w(1) { + IHP = IH + 2 * pad_h; + IWP = IW + 2 * pad_w; + OH = (IHP - KH) / stride_h + 1; + OW = (IWP - KW) / stride_w + 1; + } + + /** + * @brief Helper function to get convolution parameters as string. + */ + std::string toString() const { + std::string out = ""; + out += "MB:" + std::to_string(MB) + ", "; + out += "IC:" + std::to_string(IC) + ", "; + out += "OC:" + std::to_string(OC) + ", "; + out += "IH:" + std::to_string(IH) + ", "; + out += "IW:" + std::to_string(IW) + ", "; + out += "G:" + std::to_string(G) + ", "; + out += "KH:" + std::to_string(KH) + ", "; + out += "KW:" + std::to_string(KW) + ", "; + out += "stride_h:" + std::to_string(stride_h) + ", "; + out += "stride_w:" + std::to_string(stride_w) + ", "; + out += "pad_h:" + std::to_string(pad_h) + ", "; + out += "pad_w:" + std::to_string(pad_w); + return out; + } +}; + +} // namespace fbgemm2 diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h new file mode 100644 index 0000000..988e24b --- /dev/null +++ b/include/fbgemm/Fbgemm.h @@ -0,0 +1,952 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +/** + * Top level include file for FBGEMM. 
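+ *
+ * Illustrative sketch of the typical flow, mirroring the benchmarks in this
+ * commit (template arguments are omitted here; buffer names are the
+ * benchmarks' own):
+ * \code
+ *   PackAWithRowOffset packA(matrix_op_t::NoTranspose, m, k, Aint8, k,
+ *                            nullptr, 1, Aint8_zero_point,
+ *                            row_offset_buf.data());
+ *   PackBMatrix packB(matrix_op_t::NoTranspose, k, n, Bint8, n,
+ *                     nullptr, 1, Bint8_zero_point);
+ *   DoNothing<> doNothingObj{};
+ *   ReQuantizeOutput outputProcObj(doNothingObj, C_multiplier, C_zero_pt,
+ *                                  Aint8_zero_point, Bint8_zero_point,
+ *                                  packA.getRowOffsetBuffer(), col_offsets,
+ *                                  nullptr);
+ *   fbgemmPacked(packA, packB, Cint8, Cint32_buffer, n, outputProcObj, 0, 1);
+ * \endcode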
+ */ +#include +#include +#include +#include +#include +#include +#include "ConvUtils.h" +#include "FbgemmI8Spmdm.h" +#include "Types.h" +#include "Utils.h" + +// #define FBGEMM_MEASURE_TIME_BREAKDOWN + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +#include +#include +extern double packing_time; +extern double computing_time; +extern double kernel_time; +extern double postprocessing_time; +extern double run_time; +#endif + +namespace fbgemm2 { + +/** + * @brief Templatized struct for packing parameters for A and B matrices. + * + * @tparam T input type + * @tparam accT the type used for accumulation + * @tparam instSet anyarch/avx2/avx512 + * @tparam int8Type an auxiliary template parameter to specialize for 8-bit + * input types. + */ +template < + typename T, + typename accT, + inst_set_t instSet, + typename int8Type = void> +struct PackingTraits; + +// type specialized implementation in an include file +#include "PackingTraits-inl.h" + +/** + * @brief Base class for packing matrices for higher GEMM performance. + * + * Matrix is tiled into blockRows() * blockCols() blocks. + * Each block is with size blockRowSize() * blockColSize(). + * This class is designed using CRTP + * (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) + * + * @tparam PT actual packing type, e.g., PackAWithRowOffset + */ +template +class PackMatrix { + public: + PackMatrix() = delete; // no default constructor + + /** + * @param rows total number of rows in the matrix + * (packed rows can be less than rows). + * @param cols total number of columns in the matrix + * @param pmat A buffer to contain the packed matrix. + * If nullptr, a buffer owned by PackMatrix will be allocated + * internally to contain the packed matrix. + * For non-constant matrices like activation matrices, the client + * code may want to pass a pre-allocated pmat to avoid the + * overhead of internal memory allocation everytime a PackMatrix + * is constructed. The client code can query how big patm should + * be with packedBufferSize function. + * @param zero_pt the quantized value that maps to 0.0f floating-point number. + */ + PackMatrix( + std::int32_t rows, + std::int32_t cols, + inpType* pmat, + std::int32_t zero_pt); + + /** + * @return true usually when the matrix is constant matrix (e.g., weight + * matrices) that can be prepacked + */ + bool isPrePacked() const { + return static_cast(this)->isPrePacked(); + } + + /** + * @return true if this is the first input matrix in GEMM (i.e., A in C = A * + * B) + */ + static constexpr bool isA() { + return PT::isA(); + } + + /** + * @brief The size of the buffer used for packing (The size is in number of + * elements). + * + * rows and cols are only used for fully packing, i.e., for B matrix. The + * client code can use this function to query how big the buffer used for + * packing should be. + */ + static int packedBufferSize(int rows = 0, int cols = 0); + + /** + * @return Pointer to a buffer containing row offset results. Some packing + * objects fuse row offset computation for later requantization step. + */ + std::int32_t* getRowOffsetBuffer() const { + return static_cast(this)->getRowOffsetBuffer(); + } + + /** + * @brief When k loop is also tiled/blocked, this function is used to check if + * have executed computations for the last k block so that we can perform + * post-GEMM operations. + */ + bool isThisLastKBlock(int block_id) const { + return static_cast(this)->isThisLastKBlock(block_id); + } + + /** + * @brief Actual packing of a block of the source matrix in pmat buffer. 
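+ *
+ * (Illustrative note) The block argument gives the region to pack in
+ * source-matrix coordinates: its row_start/col_start offsets and its
+ * row_size/col_size extents, which numPackedRows(), numPackedCols() and
+ * packedRowStart() below report for the current block. A sketch, with
+ * packedA standing for some concrete packing-class instance:
+ * \code
+ *   block_type_t block;   // row_start, row_size, col_start, col_size
+ *   // ... set the bounds of the tile to pack, then:
+ *   packedA.pack(block);
+ * \endcode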
+ */ + void pack(const block_type_t& block) { + static_cast(this)->pack(block); + } + + std::int32_t numRows() const { + return nrows_; + } + + std::int32_t numCols() const { + return ncols_; + } + + /** + * @return The number of rows in each block + */ + std::int32_t blockRowSize() const { + return brow_; + } + + /** + * @return The number of columns in each block + */ + std::int32_t blockColSize() const { + return bcol_; + } + + /** + * @return The number of blocks along rows + */ + std::int32_t blockRows() const { + return nbrow_; + } + + /** + * @return The number of blocks along columns + */ + std::int32_t blockCols() const { + return nbcol_; + } + + /** + * @return The number of the rows in the currently packed block of a matrix. + * For pre-packed (i.e., fully-packed), it's equal to the total number + * of rows. + */ + std::int32_t numPackedRows() const { + return packedBlock_.row_size; + } + + /** + * @return The number of columns in the currently packed block of a matrix. + * For pre-packed (i.e., fully-packed), it's equal to the number of + * columns. + */ + std::int32_t numPackedCols() const { + return packedBlock_.col_size; + } + + /** + * @return The first row of the block we're working on. + */ + std::int32_t packedRowStart() const { + return packedBlock_.row_start; + } + + /** + * @return The beginning of (rowBlockNum, colBlockNum)th block + */ + inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) { + return buf_ + blockRowSize() * blockColSize() * rowBlockNum + + blockRowSize() * blockColSize() * blockCols() * colBlockNum; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(std::string name) { + static_cast(this)->printPackedMatrix(name); + } + + /** + * @return The number of rows in the last row block. + */ + std::int32_t lastBrow() const { + return last_brow_; + } + + /** + * @return The number of columns in the last column block. + */ + std::int32_t lastBcol() const { + return last_bcol_; + } + + /** + * @return True if the last column block has fewer columns than the block + * size. + */ + bool isThereColRemainder() const { + return last_bcol_ != blockColSize(); + } + + ~PackMatrix() { + if (bufAllocatedHere_) { + free(buf_); + } + } + + protected: + /** + * Set which block we're packing + */ + void packedBlock(const block_type_t& block) { + packedBlock_ = block; + nbrow_ = (numPackedRows() + blockRowSize() - 1) / blockRowSize(); + nbcol_ = (numPackedCols() + blockColSize() - 1) / blockColSize(); + + last_brow_ = ((numPackedRows() % blockRowSize()) == 0) + ? blockRowSize() + : (numPackedRows() % blockRowSize()); + last_bcol_ = ((numPackedCols() % blockColSize()) == 0) + ? blockColSize() + : (numPackedCols() % blockColSize()); + } + + /** + * @return the quantized value that maps to 0.0f floating-point number + */ + std::int32_t zeroPoint() const { + return zero_pt_; + } + + inpType* buf_; + std::int32_t brow_; ///< the number of rows in each block + std::int32_t bcol_; ///< the number of columns in each block + std::int32_t nbrow_; ///< the number of blocks along rows + std::int32_t nbcol_; ///< the number of blocks along columns + bool bufAllocatedHere_; + + private: + std::int32_t nrows_, ncols_; + std::int32_t zero_pt_; + block_type_t packedBlock_; ///< The block in the source matrix just packed + std::int32_t last_brow_, last_bcol_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually + * activation). The source matrix is already quantized. Default + * accumulation type is int32. 
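+ *
+ * Illustrative use, mirroring the "bare bone" path of the benchmarks in this
+ * commit (template arguments omitted in this sketch):
+ * \code
+ *   PackAMatrix packA(matrix_op_t::NoTranspose, m, k, Aint8, k,
+ *                     nullptr, 1, Aint8_zero_point);
+ *   DoNothing doNothing32BitObj;
+ *   memCopy<> memcopyObj(doNothing32BitObj);
+ *   fbgemmPacked(packA, packedB, Cint32, Cint32, n, memcopyObj, 0, 1);
+ * \endcode
+ * Since getRowOffsetBuffer() returns nullptr here, pair this class with an
+ * output stage that does not need row offsets (or use a zero B zero point),
+ * as the accompanying benchmarks do; otherwise use one of the
+ * PackAWith*RowOffset classes.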
+ */ +template +class PackAMatrix : public PackMatrix, T, accT> { + public: + using This = PackAMatrix; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAMatrix() = delete; // no default constructor + + /** + * TODO: currently only groups == 1 supported + */ + PackAMatrix( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + const inpType* smat, + std::int32_t ld, + inpType* pmat = nullptr, + std::int32_t groups = 1, + accT zero_pt = 0); + + /** + * Activation matrices are not constant so cannot amortize the cost of + * pre-packing. + */ + bool isPrePacked() const { + return false; + } + + /** + * @return True is this is used as A matrix. + */ + static constexpr bool isA() { + return true; + } + + /** + * @return A pointer to the row offset buffer. There is no row offset buffer + * calculations with this packing class, hence, it returns nullptr. + */ + std::int32_t* getRowOffsetBuffer() const { + return nullptr; + } + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix. + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(std::string name); + + private: + matrix_op_t trans_; + const T* smat_; + std::int32_t ld_; + std::int32_t G_; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the second input matrix in GEMM (usually weight). + * The source matrix is already quantized. Default accumulation + * type is int32. + */ +template +class PackBMatrix : public PackMatrix, T, accT> { + public: + using This = PackBMatrix; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackBMatrix() = delete; // no default constructor + + /** + * TODO: Currently only groups == 1 supported. + */ + PackBMatrix( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + const inpType* smat, + std::int32_t ld, + inpType* pmat = nullptr, + std::int32_t groups = 1, + accT zero_pt = 0); + + /** + * Weight matrices are usually constant so worth pre-packing. + */ + bool isPrePacked() const { + return true; + } + + /** + * @return True if to be used as A matrix, False otherwise. + */ + static constexpr bool isA() { + return false; + } + + /** + * @brief When k loop is also tiled/blocked, this function is used to check if + * have executed computations for the last k block so that we can perform + * post-GEMM operations. + */ + bool isThisLastKBlock(int block_id) const { + return (BaseType::blockRows() - 1) == block_id; + } + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix. + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(std::string name); + + ~PackBMatrix() {} + + private: + matrix_op_t trans_; + const T* smat_; + std::int32_t ld_; + std::int32_t G_; + std::int32_t row_interleave_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * Im2col is fused with packing here. The source matrix is already + * quantized. 
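+ *
+ * Illustrative use, mirroring the im2col benchmarks in this commit (template
+ * arguments omitted in this sketch):
+ * \code
+ *   std::vector<std::int32_t> row_offset_buf(
+ *       PackAWithIm2Col::rowOffsetBufferSize());
+ *   PackAWithIm2Col packA(conv_p, Aint8, nullptr,
+ *                         Aint8_zero_point, row_offset_buf.data());
+ * \endcode
+ * After im2col the implied GEMM shape is M = MB*OH*OW, N = OC, K = KH*KW*IC,
+ * so the same PackBMatrix/fbgemmPacked path as for an ordinary GEMM applies.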
+ */ +template +class PackAWithIm2Col : public PackMatrix, T, accT> { + public: + using This = PackAWithIm2Col; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAWithIm2Col() = delete; // no default constructor + /** + * TODO: Currently only groups == 1 supported + */ + PackAWithIm2Col( + const conv_param_t& conv_param, + const T* sdata, + inpType* pmat = nullptr, + std::int32_t zero_pt = 0, + std::int32_t* row_offset = nullptr); + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @return A pointer to the row offset buffer. + */ + std::int32_t* getRowOffsetBuffer() const { + return row_offset_; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(std::string name); + + /** + * @return Size of row offset buffer in number of elements + */ + static int rowOffsetBufferSize(); + + ~PackAWithIm2Col() { + if (rowOffsetAllocatedHere) { + free(row_offset_); + } + } + + private: + const conv_param_t& conv_p_; + const T* sdata_; + std::int32_t* row_offset_; + bool rowOffsetAllocatedHere; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * The source matrix is already quantized. + */ +template +class PackAWithRowOffset + : public PackMatrix, T, accT> { + public: + using This = PackAWithRowOffset; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAWithRowOffset() = delete; // no default constructor + /** + * TODO: Currently only groups == 1 supported + */ + PackAWithRowOffset( + matrix_op_t trans, + std::uint32_t nRow, + std::uint32_t nCol, + const T* smat, + std::uint32_t ld, + inpType* pmat = nullptr, + std::uint32_t groups = 1, + std::int32_t zero_pt = 0, + std::int32_t* row_offset = nullptr); + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @return A pointer to the row offset buffer. + */ + std::int32_t* getRowOffsetBuffer() const { + return row_offset_; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(std::string name); + + /** + * @return size of row offset buffer in number of elements + */ + static int rowOffsetBufferSize(); + + ~PackAWithRowOffset() { + if (rowOffsetAllocatedHere) { + free(row_offset_); + } + } + + private: + matrix_op_t trans_; + const T* smat_; + std::uint32_t ld_; + std::uint32_t G_; + std::int32_t* row_offset_; + bool rowOffsetAllocatedHere; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * The source matrix is in fp32 and quantized during packing. 
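+ *
+ *        A minimal usage sketch, assuming sizes M and K, an fp32 activation
+ *        buffer, and example quantization parameters (scale 0.25, zero point
+ *        3); all of these values are illustrative only.
+ *
+ * @code
+ *   std::vector<float> A(M * K); // fp32 activations
+ *   std::vector<std::int32_t> row_offsets(
+ *       PackAWithQuantRowOffset<std::uint8_t>::rowOffsetBufferSize());
+ *   PackAWithQuantRowOffset<std::uint8_t> packA(
+ *       matrix_op_t::NoTranspose, M, K, A.data(), K,
+ *       nullptr, // let the class allocate the packed buffer
+ *       0.25f,   // quantization scale
+ *       3,       // quantization zero point
+ *       1,       // groups
+ *       row_offsets.data());
+ * @endcode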
+ */
+template 
+class PackAWithQuantRowOffset
+    : public PackMatrix, T, accT> {
+ public:
+  using This = PackAWithQuantRowOffset;
+  using BaseType = PackMatrix;
+  using inpType = T;
+  using accType = accT;
+
+  PackAWithQuantRowOffset() = delete; // no default constructor
+  /**
+   * TODO: Currently only groups == 1 supported
+   */
+  PackAWithQuantRowOffset(
+      matrix_op_t trans,
+      std::int32_t nRow,
+      std::int32_t nCol,
+      const float* smat,
+      std::int32_t ld,
+      inpType* pmat = nullptr,
+      float scale = 1.0f,
+      std::int32_t zero_pt = 0,
+      std::int32_t groups = 1,
+      std::int32_t* row_offset = nullptr);
+
+  /**
+   * @return Offset of the element in the packed matrix that was at (i, j) in
+   *         the source matrix.
+   */
+  std::int32_t addr(std::int32_t i, std::int32_t j) const;
+
+  /**
+   * @brief Packs a block of source matrix into pmat buffer.
+   */
+  void pack(const block_type_t& block);
+
+  /**
+   * @return A pointer to the row offset buffer.
+   */
+  std::int32_t* getRowOffsetBuffer() const {
+    return row_offset_;
+  }
+
+  /**
+   * @brief Print the packed block.
+   */
+  void printPackedMatrix(std::string name);
+
+  /**
+   * @return Size of row offset buffer in number of elements
+   */
+  static int rowOffsetBufferSize();
+
+  ~PackAWithQuantRowOffset() {
+    if (rowOffsetAllocatedHere) {
+      free(row_offset_);
+    }
+  }
+
+ private:
+  matrix_op_t trans_;
+  const float* smat_;
+  std::int32_t ld_;
+  float scale_;
+  std::int32_t G_;
+  std::int32_t* row_offset_;
+  bool rowOffsetAllocatedHere;
+  std::int32_t row_interleave_B_;
+};
+
+/*
+ *
+ * Post Processing of outputs
+ *
+ */
+
+/**
+ * @brief Does nothing. NoOp. Used as the last operation in the output
+ *        processing pipeline.
+ *
+ */
+template 
+class DoNothing {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  DoNothing() {}
+  template 
+  int f(
+      outType* /* unused */,
+      inpType* /* unused */,
+      const block_type_t& /* unused */,
+      int /* unused */,
+      int /* unused */) const {
+    return 0;
+  }
+};
+
+/**
+ * @brief Copy data pointed to by the inp ptr to the out ptr when
+ *        the inp ptr and out ptr are not the same.
+ *        inp buffer: row and column start points: (0, 0)
+ *        output buffer: row and column start points:
+ *        (block.row_start, block.col_start)
+ *
+ * This is the output processing stage that should be passed when there is no
+ * requantization and the output is required in the same format as the internal
+ * buffer used for accumulation.
+ */
+template <
+    typename outT = std::int32_t,
+    typename inT = std::int32_t,
+    typename nextOPType = DoNothing>
+class memCopy {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  explicit memCopy(nextOPType& nextop) : nextop_(nextop) {}
+  template 
+  inline int f(
+      outType* out,
+      inpType* inp,
+      const block_type_t& block,
+      int ld_out,
+      int ld_in) const;
+
+ private:
+  nextOPType& nextop_;
+};
+
+/**
+ * @brief Perform scaling on accumulated data.
+ */
+template <
+    typename outT = std::int32_t,
+    typename inT = std::int32_t,
+    typename nextOPType = DoNothing>
+class ScaleOP {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  explicit ScaleOP(inpType scalingFactor) : scalingFactor_(scalingFactor) {}
+
+  template 
+  inline int f(
+      outType* out,
+      inpType* inp,
+      const block_type_t& block,
+      int ld_out,
+      int ld_in) const;
+
+ private:
+  inpType scalingFactor_;
+};
+
+/**
+ * @brief Perform Relu on accumulated data.
+ */
+template <
+    typename outT = std::int32_t,
+    typename inT = std::int32_t,
+    typename nextOPType = DoNothing>
+class ReluOutput {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  explicit ReluOutput(inpType zero_pt) : zero_pt_(zero_pt) {}
+
+  template 
+  inline int f(
+      outType* out,
+      inpType* inp,
+      const block_type_t& block,
+      int ld_out,
+      int ld_in) const;
+
+ private:
+  inpType zero_pt_;
+};
+
+/**
+ * @brief Perform Sparse-Matrix * Dense-Matrix as part of the output
+ * processing pipeline.
+ *
+ * SPMDM (SParse Matrix times Dense Matrix) is performed in place on the 32-bit
+ * input buffer (inp). After modifying the input buffer, it is passed to the
+ * next op.
+ */
+template <
+    typename outT = std::int32_t,
+    typename inT = std::int32_t,
+    typename nextOPType = DoNothing>
+class DoSpmdmOnInpBuffer {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  DoSpmdmOnInpBuffer(
+      nextOPType& nextop,
+      const std::uint8_t* A,
+      int lda,
+      const CompressedSparseColumn& B_csc)
+      : nextop_(nextop), A_(A), lda_(lda), B_csc_(B_csc) {}
+
+  template 
+  inline int f(
+      outT* out,
+      inT* inp,
+      const block_type_t& block,
+      int ld_out,
+      int ld_in) const;
+
+ private:
+  nextOPType& nextop_;
+  const std::uint8_t* A_;
+  const int lda_;
+  const CompressedSparseColumn& B_csc_;
+};
+
+/**
+ * @brief Requantize values in the inp buffer and write them to the out buffer.
+ *        The out buffer is then passed to the next op for further processing.
+ *
+ */
+template <
+    bool FUSE_RELU,
+    typename outT = std::uint8_t,
+    typename inT = std::int32_t,
+    typename nextOPType = DoNothing>
+class ReQuantizeOutput {
+ public:
+  using outType = outT;
+  using inpType = inT;
+  ReQuantizeOutput(
+      nextOPType& nextop,
+      float C_multiplier,
+      std::int32_t C_zero_point,
+      std::int32_t Aq_zero_point,
+      std::int32_t Bq_zero_point,
+      const std::int32_t* row_offsets,
+      const std::int32_t* col_offsets,
+      const std::int32_t* bias)
+      : nextop_(nextop),
+        C_multiplier_(C_multiplier),
+        C_zero_point_(C_zero_point),
+        Aq_zero_point_(Aq_zero_point),
+        Bq_zero_point_(Bq_zero_point),
+        q_row_offsets_(row_offsets),
+        q_col_offsets_(col_offsets),
+        bias_(bias) {}
+
+  template 
+  inline int f(
+      outT* out,
+      inT* inp,
+      const block_type_t& block,
+      int ld_out,
+      int ld_in) const;
+
+ private:
+  nextOPType& nextop_;
+  float C_multiplier_;
+  std::int32_t C_zero_point_;
+  std::int32_t Aq_zero_point_;
+  std::int32_t Bq_zero_point_;
+  const std::int32_t* q_row_offsets_;
+  const std::int32_t* q_col_offsets_;
+  const std::int32_t* bias_;
+};
+
+/**
+ * @brief Requantize to convert accumulated data to float, i.e., the
+ *        output will be used as float.
+ */ +template < + bool FUSE_RELU, + typename outT = float, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class ReQuantizeForFloat { + public: + using outType = outT; + using inpType = inT; + ReQuantizeForFloat( + nextOPType& nextop, + float Aq_scale, + float Bq_scale, + std::int32_t Aq_zero_point, + std::int32_t Bq_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const float* bias) + : nextop_(nextop), + Aq_scale_(Aq_scale), + Bq_scale_(Bq_scale), + Aq_zero_point_(Aq_zero_point), + Bq_zero_point_(Bq_zero_point), + q_row_offsets_(row_offsets), + q_col_offsets_(col_offsets), + bias_(bias) {} + + template + inline int f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + nextOPType& nextop_; + float Aq_scale_, Bq_scale_; + std::int32_t Aq_zero_point_; + std::int32_t Bq_zero_point_; + const std::int32_t* q_row_offsets_; + const std::int32_t* q_col_offsets_; + const float* bias_; +}; + +// type specialized implementation in an include file +#ifdef __AVX2__ +#include "OutputProcessing-inl.h" +#endif + +/* + * + * ####### GEMM related functions ####### + * + */ + +/** + * Matrix B must be prepacked. For matrix A, packA.pack function is called to + * pack it. + * + * @tparam packingAMatrix processing of A matrix while packing, + * e.g., PackAWithQuantRowOffset + * + * @tparam packingBMatrix processing of B matrix while packing, + * e.g., pre-multiply by alpha + * @tparam cT data type of C matrix + * @tparam processOutputType further processing of outputs, e.g., Relu + */ +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +void fbgemmPacked( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + cT* C, + std::int32_t* C_buffer, + std::uint32_t ldc, + const processOutputType& outProcess, + int thread_id, + int num_threads); + +/** + * @brief Perform depthwise separable convolution + */ + +template < + typename packingAMatrix, + typename packingBMatrix, + typename outT, + typename processOutputType> +void convDepthwiseSeparable( + const conv_param_t& conv_param_dw, + const conv_param_t& conv_param_1x1, + packingAMatrix& packdw, + packingBMatrix& packed_1x1, + outT* out, + const processOutputType& output); + +} // namespace fbgemm2 diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h new file mode 100644 index 0000000..55718d4 --- /dev/null +++ b/include/fbgemm/FbgemmFP16.h @@ -0,0 +1,160 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +// WARNING: this is a legacy fp16 fbgemm implementation and will soon be +// upgraded to match with new fbgemm interface. 
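+//
+// A minimal usage sketch, assuming sizes m, k, n, fp32 buffers, and
+// alpha = 1, beta = 0 (all illustrative): the weight matrix B is packed once
+// with PackedGemmMatrixFP16 and then reused across cblas_gemm_compute() calls.
+//
+//   std::vector<float> A(m * k), B(k * n), C(m * n, 0.f);
+//   fbgemm2::PackedGemmMatrixFP16 Bp(
+//       fbgemm2::matrix_op_t::NoTranspose, k, n, 1.f /* alpha */, B.data());
+//   fbgemm2::cblas_gemm_compute(
+//       fbgemm2::matrix_op_t::NoTranspose, m, A.data(), Bp, 0.f /* beta */,
+//       C.data());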
+
+#include 
+#include 
+#include 
+
+#include "Types.h"
+#include "Utils.h"
+
+namespace fbgemm2 {
+
+/// class that performs packing of a matrix in
+/// row-major format into
+/// an internal packed blocked-row-major format
+class PackedGemmMatrixFP16 {
+public:
+  // Takes the smat input matrix in row-major format,
+  // packs it into a gemm-friendly blocked format,
+  // allocates space, and sets up all the internal variables;
+  // it also premultiplies by alpha during packing.
+  // brow_ contains the tile size along the k dimension
+  // and is also the # of fma updates into the int16 container
+  // before flushing into fp32;
+  // the smaller brow_ is, the higher the overhead
+  // of flushing.
+  PackedGemmMatrixFP16(const matrix_op_t trans, const int nrow,
+                       const int ncol, const float alpha,
+                       const float *smat,
+                       const int brow = 512)
+      : nrow_(nrow), ncol_(ncol), brow_(brow) {
+
+    bcol_ = 8 * 1; // hardwired
+
+    // set up internal packing parameters
+    nbrow_ = ((numRows() % blockRowSize()) == 0)
+                 ? (numRows() / blockRowSize())
+                 : ((numRows() + blockRowSize()) / blockRowSize());
+    last_brow_ = ((nrow % blockRowSize()) == 0) ? blockRowSize()
+                                                : (nrow % blockRowSize());
+    nbcol_ = ((numCols() % blockColSize()) == 0)
+                 ? (numCols() / blockColSize())
+                 : ((numCols() + blockColSize()) / blockColSize());
+
+    if (numCols() != blockColSize() * nbcol_) {
+#ifdef VLOG
+      VLOG(0)
+          << "Packer warning: ncol(" << numCols()
+          << ") is not a multiple of internal block size (" << blockColSize()
+          << ")";
+      VLOG(0)
+          << "leftover is currently done via MKL: hence overhead will be incurred";
+#endif
+    }
+
+    // allocate and initialize packed memory
+    const int padding = 1024; // required by sw pipelined kernels
+    size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
+    pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) + padding);
+    for (auto i = 0; i < matSize(); i++) {
+      pmat_[i] = tconv(0.f, pmat_[i]);
+    }
+
+    // copy source matrix into packed matrix
+    this->packFromSrc(trans, alpha, smat);
+  }
+
+  ~PackedGemmMatrixFP16() {
+    free(pmat_);
+  }
+
+// protected:
+  // blocked row-major format address arithmetic
+  uint64_t addr(const int r_, const int c_) const {
+    uint64_t r = (uint64_t)r_;
+    uint64_t c = (uint64_t)c_;
+
+    uint64_t block_row_id = r / blockRowSize(),
+             brow_offset =
+                 (block_row_id * nbcol_) * (blockRowSize() * blockColSize());
+    uint64_t block_col_id = c / blockColSize(),
+             bcol_offset =
+                 block_col_id * ((block_row_id != nbrow_ - 1)
+                                     ? (blockRowSize() * blockColSize())
+                                     : (last_brow_ * blockColSize()));
+    uint64_t block_offset = brow_offset + bcol_offset;
+    uint64_t inblock_offset =
+        r % blockRowSize() * blockColSize() + c % blockColSize();
+
+    uint64_t index = block_offset + inblock_offset;
+    assert(index < matSize());
+    return index;
+  }
+
+  void packFromSrc(const matrix_op_t trans, const float alpha,
+                   const float *smat) {
+    bool tr = (trans == matrix_op_t::Transpose);
+    // pack
+    for (int i = 0; i < numRows(); i++) {
+      for (int j = 0; j < numCols(); j++) {
+        pmat_[addr(i, j)] = tconv(
+            alpha * (
+                (tr == false)
+                    ?
smat[i * numCols() + j] : smat[i + numRows() * j]), + pmat_[addr(i, j)]); + } + } + } + + const float16 &operator()(const int r, const int c) const { + uint64_t a = addr(r, c); + assert(r < numRows()); + assert(c < numCols()); + assert(a < this->matSize()); + return pmat_[a]; + } + + int matSize() const { return size_; } + int numRows() const { return nrow_; } + int numCols() const { return ncol_; } + inline int blockRowSize() const { return brow_; } + inline int blockColSize() const { return bcol_; } + + int nrow_, ncol_; + int brow_, last_brow_, bcol_; + int nbrow_, nbcol_; + uint64_t size_; + float16 *pmat_; + + friend void cblas_gemm_compute(const matrix_op_t transa, const int m, + const float *A, + const PackedGemmMatrixFP16 &Bp, + const float beta, float *C); + friend void cblas_gemm_compute(const matrix_op_t transa, const int m, + const float *A, + const PackedGemmMatrixFP16 &Bp, + const float beta, float *C); +}; + +/** + * restrictions: transa == CblasNoTrans + */ +extern void cblas_gemm_compute(const matrix_op_t transa, const int m, + const float *A, + const PackedGemmMatrixFP16 &Bp, + const float beta, float *C); +extern void cblas_gemm_compute(const matrix_op_t transa, const int m, + const float *A, + const PackedGemmMatrixFP16 &Bp, + const float beta, float *C); + +}; // namespace fbgemm diff --git a/include/fbgemm/FbgemmI8Spmdm.h b/include/fbgemm/FbgemmI8Spmdm.h new file mode 100644 index 0000000..264b70e --- /dev/null +++ b/include/fbgemm/FbgemmI8Spmdm.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include "Utils.h" + +// #define FBGEMM_MEASURE_TIME_BREAKDOWN + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +#include +#include +extern double spmdm_initial_time; +extern double spmdm_transpose_uint8_time; +extern double spmdm_transpose_32xN_time; +extern double spmdm_compute_time; +extern double spmdm_transpose_Nx32_time; +extern double spmdm_run_time; +#endif + +namespace fbgemm2 { + +/** + * @brief A class to represent a matrix in Compressed Sparse Column (CSC) + * format. + * + * The second input matrix of matrix multiplication is usually weight and can + * be sparse, and it's usually more efficient to use CSC format to represent + * the second input matrix. + */ +class CompressedSparseColumn { + public: + CompressedSparseColumn(int num_of_rows, int num_of_cols); + + std::vector& ColPtr() { + return colptr_; + } + std::vector& RowIdx() { + return rowidx_; + } + std::vector& Values() { + return values_; + } + + std::size_t NumOfRows() const { + return num_rows_; + } + std::size_t NumOfCols() const { + return colptr_.size() - 1; + } + std::int32_t NumOfNonZeros() const { + return colptr_.back(); + } + + /** + * @return Total number of non-zero elements as a fraction of total + * elements. + */ + double Density() const; + + /** + * @return True if the number of non-zeros per row is smaller than a small + * threshold. + */ + bool IsHyperSparse() const; + + /** + * @brief Perform dense-matrix * sparse matrix. 
+ * + * C += A (dense matrix) * B (this CSC matrix) if accumulation = true \n + * C = A (dense matrix) * B (this CSC matrix) if accumulation = false + */ + void SpMDM( + const block_type_t& block, + const std::uint8_t* A, + int lda, + bool accumulation, + std::int32_t* C, + int ldc) const; + + private: + const std::size_t num_rows_; + std::vector colptr_; + std::vector rowidx_; + std::vector values_; + + // Cache IsHyperSparse to minimize its overhead. + mutable bool hyper_sparse_; + + // Whether we can reuse the cached hyper_sparse_ is determined by checking + // if NumOfNonZeros() is same as old_nnz_ saved in previous invocation of + // IsHyperSparse call. + mutable std::int32_t old_nnz_; +}; + +} // namespace fbgemm2 diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h new file mode 100644 index 0000000..13f614a --- /dev/null +++ b/include/fbgemm/OutputProcessing-inl.h @@ -0,0 +1,356 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +template +template +inline int memCopy::f(outT* out, inT* inp, + const block_type_t& block, int ld_out, int ld_in) const { + static_assert( + std::is_same::value, + "input and output data type must be of same type"); + // only copy if destination is not the same as source + if (out + block.row_start*ld_out + block.col_start != inp) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + memcpy(out + block.col_start + i * ld_out, + inp + (i - block.row_start) * ld_in, + block.col_size*sizeof(inT)); + } + } + return nextop_.template f(out, out, block, ld_out, ld_out); +} + +template +template +inline int DoSpmdmOnInpBuffer::f(outT* out, inT* inp, + const block_type_t& block, int ld_out, int ld_in) const { + B_csc_.SpMDM(block, A_, lda_, true, inp, ld_in); + return nextop_.template f(out, inp, block, ld_out, ld_in); +} + +template +template +inline int ReQuantizeOutput::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + static_assert( + std::is_same::value, + "input data type must be of int32_t type"); + if (instSet == inst_set_t::anyarch) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + inT raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)]; + raw -= Aq_zero_point_ * q_col_offsets_[j]; + if (q_row_offsets_) { + raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_; + } + if (bias_) { + raw += bias_[j]; + } + + float ab = raw * C_multiplier_; + long rounded = std::lrintf(ab) + C_zero_point_; + + out[i * ld_out + j] = std::max( + FUSE_RELU ? 
static_cast(C_zero_point_) : 0l, + std::min(255l, rounded)); + } + } + } else if (instSet == inst_set_t::avx2) { + if (std::is_same::value) { + // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c + // using AVX2 instructions + __m256 multiplier_v = _mm256_set1_ps(C_multiplier_); + + __m256i min_v = _mm256_set1_epi8(std::numeric_limits::min()); + __m256i max_v = _mm256_set1_epi8(std::numeric_limits::max()); + + __m256i A_zero_point_v = _mm256_set1_epi32(Aq_zero_point_); + __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point_); + __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point_); + + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + std::int32_t row_offset = q_row_offsets_ + ? q_row_offsets_[i - block.row_start] * Bq_zero_point_ + : 0; + __m256i row_offset_v = _mm256_set1_epi32(row_offset); + int j = block.col_start; + for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4)); + j += (VLEN * 4)) { + __m256i x_v = _mm256_loadu_si256(reinterpret_cast( + inp + (i - block.row_start) * ld_in + (j - block.col_start))); + __m256i y_v = _mm256_loadu_si256(reinterpret_cast( + inp + (i - block.row_start) * ld_in + (j - block.col_start) + + 1 * VLEN)); + __m256i z_v = _mm256_loadu_si256(reinterpret_cast( + inp + (i - block.row_start) * ld_in + (j - block.col_start) + + 2 * VLEN)); + __m256i w_v = _mm256_loadu_si256(reinterpret_cast( + inp + (i - block.row_start) * ld_in + (j - block.col_start) + + 3 * VLEN)); + + //if (A_zero_pt != 0) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(q_col_offsets_ + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(q_col_offsets_ + j + VLEN))); + y_v = _mm256_sub_epi32(y_v, col_off_v); + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast( + q_col_offsets_ + j + 2 * VLEN))); + z_v = _mm256_sub_epi32(z_v, col_off_v); + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast( + q_col_offsets_ + j + 3 * VLEN))); + w_v = _mm256_sub_epi32(w_v, col_off_v); + //} + + // if (row_offset != 0) { + x_v = _mm256_sub_epi32(x_v, row_offset_v); + y_v = _mm256_sub_epi32(y_v, row_offset_v); + z_v = _mm256_sub_epi32(z_v, row_offset_v); + w_v = _mm256_sub_epi32(w_v, row_offset_v); + //} + if (bias_) { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast(bias_ + j))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast(bias_ + j + VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast(bias_ + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast(bias_ + j + 3 * VLEN))); + } + + /* + * Convert int32_t input to FP32 and multiply by FP32 scale. + * Both operations involve statistically unbiased roundings (with + * default MXCSR rounding mode): + * - Large int32_t values can't be exactly represented as FP32. + * CVTDQ2PS instruction on x86 would round it according to nearest + * FP32 value with ties to even (assuming default MXCSR rounding + * mode). + * - Product of two FP32 values is generally not exactly + * representation as an FP32 value, and will be rounded to nearest + * FP32 value with ties to even with default MXCSR rounding mode. 
+ */ + __m256 x_scaled_v = + _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + __m256 y_scaled_v = + _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + __m256 z_scaled_v = + _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + __m256 w_scaled_v = + _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + + /* + * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction. + * CVTPS2DQ instruction rounds result according to nearest FP32 value + * with ties to even (assuming default MXCSR rounding mode). However, + * when conversion overflows, it produces INT32_MIN as a result. For + * large positive inputs the result of conversion can become negative, + * which affects the final requantization result. Note that on x86 + * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This + * happens because float(INT32_MAX) rounds to 2**31, which overflows + * int32_t when it is converted back to integer. + * + * Thankfully, we can prove that overflow never happens in this + * requantization scheme. The largest positive input is INT32_MAX + * (2**31 - 1), which turns into 2**31 when converted to float. The + * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the + * result is 2147483520 (compare to INT32_MAX = 2147483647), which + * fits into int32_t without overflow. + */ + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + /* + * Standard final sequence on x86 AVX2: + * - Pack to int16_t and saturate + * - Add zero point + * - Pack to uint8_t and saturate + * - Clamp between qmin and qmax + */ + __m256i xy_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, y_rounded_v), + C_zero_point_epi16_v); + __m256i zw_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(z_rounded_v, w_rounded_v), + C_zero_point_epi16_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + __m256i xyzw_clamped_v = _mm256_max_epu8( + FUSE_RELU ? 
C_zero_point_epi8_v : min_v, + _mm256_min_epu8(xyzw_packed_v, max_v)); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 + */ + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + + /* + * 4x CVTDQ2PS + * 4x MULPS + * 4x CVTPS2DQ + * 2x PACKSSDW + * 1x PACKUSWB + * 2x PADDW + * 1x PMAXUB + * 1x PMINUB + * 1x PERMD + * --------------------- + * 20 instructions total + */ + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v); + } // j loop vectorized and unrolled 4x + + for (; j < block.col_start + (block.col_size / VLEN * VLEN); + j += VLEN) { + __m256i x_v = _mm256_loadu_si256(reinterpret_cast( + inp + (i - block.row_start) * ld_in + (j - block.col_start))); + + //if (A_zero_pt != 0) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(q_col_offsets_ + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + //} + + // if (row_offset != 0) { + x_v = _mm256_sub_epi32(x_v, row_offset_v); + //} + if (bias_) { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast(bias_ + j))); + } + + __m256 x_scaled_v = + _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); + + /* + * x_clamped_v has results in the following layout so we need to + * permute: x0-3 garbage0-11 x4-7 garbage12-23 + */ + x_clamped_v = + _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + /* + * 1x CVTDQ2PS + * 1x MULPS + * 1x CVTPS2DQ + * 1x PACKSSDW + * 1x PACKUSWB + * 1x PADDW + * 1x PMAXUB + * 1x PMINUB + * 1x PERMD + * --------------------- + * 9 instructions total + */ + _mm_storel_epi64( + reinterpret_cast<__m128i*>(out + i * ld_out + j), + _mm256_castsi256_si128(x_clamped_v)); + } // j loop vectorized + + for (; j < block.col_start + block.col_size; ++j) { + int32_t raw = + inp[(i - block.row_start) * ld_in + (j - block.col_start)]; + // if (A_zero_pt != 0) { + raw -= Aq_zero_point_ * q_col_offsets_[j]; + //} + raw -= row_offset; + if (bias_) { + raw += bias_[j]; + } + + float ab = raw * C_multiplier_; + long rounded = std::lrintf(ab) + C_zero_point_; + + out[i * ld_out + j] = std::max( + FUSE_RELU ? 
static_cast(C_zero_point_) : 0l, + std::min(255l, rounded)); + } // j loop remainder + } // i loop + } else { + assert(0 && "Not supported yet"); + } + } else { + assert(0 && "Not supported yet"); + } + return nextop_.template f(out, out, block, ld_out, ld_out); +} + +template +template +inline int ReQuantizeForFloat::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + static_assert( + std::is_same::value, + "input data type is of not expected type"); + static_assert( + std::is_same::value, + "output data type is of not expected type"); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start]; + raw -= Aq_zero_point_ * q_col_offsets_[j]; + raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_; + float res = raw * Aq_scale_ * Bq_scale_; + if (bias_) { + res += bias_[j]; + } + out[i * ld_out + j] = res; + if (FUSE_RELU) { + out[i * ld_out + j] = std::max(0.0f, out[i * ld_out + j]); + } + } + } + + return nextop_.template f(out, out, block, ld_out, ld_out); +} diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h new file mode 100644 index 0000000..35cd50a --- /dev/null +++ b/include/fbgemm/PackingTraits-inl.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +/** + * @brief Packing parameter specialization for accumulation into 32-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx2. + */ +template +struct PackingTraits< + T, + std::int32_t, + inst_set_t::avx2, + typename std::enable_if::value>::type> { + static constexpr int MR{12}; ///< Register block for M dimension + static constexpr int NR{8}; ///< Register block for N dimension + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 120}; ///< cache block for M dimension (multiple of MR) + static constexpr int NCB{8}; ///< cache block for N dimension (multiple of NR) + static constexpr int KCB{512}; ///< cache block for K dimension +}; + +/** + * @brief Packing parameter specialization for accumulation into 16-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx2. + */ +template +struct PackingTraits< + T, + std::int16_t, + inst_set_t::avx2, + typename std::enable_if::value>::type> { + static constexpr int MR{3}; ///< Register block for M dimension + static constexpr int NR{16}; ///< Register block for N dimension; Total + ///< register used for N dimension: NCB/NR + + static constexpr int ROW_INTERLEAVE{ + 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 60}; ///< cache block for M dimension (multiple of MR) + static constexpr int NCB{ + 64}; ///< cache block for N dimension (multiple of NR) + static constexpr int KCB{256}; ///< cache block for K dimension +}; + +/** + * @brief Packing parameter specialization for float input and float + * accumulation. + * + * This is picked when template paramtere T is of float type and instruction + * set is avx2. 
+ */ +template <> +struct PackingTraits { + static constexpr int MR{3}; ///< Register block for M dimension + static constexpr int NR{32}; ///< Register block for N dimension + + static constexpr int ROW_INTERLEAVE{1}; ///< No Row interleave. + + static constexpr int MCB{ + 24}; ///< cache block for M dimension (multiple of MR) + static constexpr int NCB{ + 64}; ///< cache block for N dimension (multiple of NR) + static constexpr int KCB{256}; ///< cache block for K dimension +}; + +/** + * @brief Packing parameter specialization for fp16 input and float + * accumulation. + * + * This is picked when template parameter T is of float16 type and instruction + * set is avx2. + */ +template <> +struct PackingTraits { + static constexpr int BCOL{8}; + static constexpr int ROW_INTERLEAVE{1}; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512. + */ +template +struct PackingTraits< + T, + std::int32_t, + inst_set_t::avx512, + typename std::enable_if::value>::type> { + static constexpr int MR{28}; ///< Register block for M dimension + static constexpr int NR{16}; ///< Register block for N dimension + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix + + static constexpr int MCB{ + 140}; ///< cache block for M dimension (multiple of MR) + static constexpr int NCB{ + 16}; ///< cache block for N dimension (multiple of NR) + static constexpr int KCB{512}; ///< cache block for K dimension +}; + +/** + * @brief Packing parameter specialization for accumulation into 16-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512. + */ +template +struct PackingTraits< + T, + std::int16_t, + inst_set_t::avx512, + typename std::enable_if::value>::type> { + static constexpr int MR{6}; ///< Register block for M dimension + static constexpr int NR{32}; ///< Register block for N dimension; Total + ///< register used for N dimension: NCB/NR + + static constexpr int ROW_INTERLEAVE{ + 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 60}; ///< cache block for M dimension (multiple of MR) + static constexpr int NCB{ + 128}; ///< cache block for N dimension (multiple of NR) + static constexpr int KCB{256}; ///< cache block for K dimension +}; diff --git a/include/fbgemm/Types.h b/include/fbgemm/Types.h new file mode 100644 index 0000000..c5c62dd --- /dev/null +++ b/include/fbgemm/Types.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include + +namespace fbgemm2 { + +typedef struct __attribute__((aligned(2))) __f16 { + uint16_t x; +} float16; + +static inline float16 cpu_float2half_rn(float f) { + float16 ret; + + static_assert(sizeof(unsigned int) == sizeof(float), + "Programming error sizeof(unsigned int) != sizeof(float)"); + + unsigned *xp = reinterpret_cast(&f); + unsigned x = *xp; + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. 
+ if (u > 0x7f800000) { + ret.x = 0x7fffU; + return ret; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + ret.x = sign | 0x7c00U; + return ret; + } + if (u < 0x33000001) { + ret.x = (sign | 0x0000); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + ret.x = (sign | (exponent << 10) | mantissa); + + return ret; +} + +static inline float cpu_half2float(float16 h) { + unsigned sign = ((h.x >> 15) & 1); + unsigned exponent = ((h.x >> 10) & 0x1f); + unsigned mantissa = ((h.x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + unsigned i = ((sign << 31) | (exponent << 23) | mantissa); + float ret; + memcpy(&ret, &i, sizeof(i)); + return ret; +} + +static inline uint8_t tconv(float x, int8_t /*rtype*/) { return int8_t(x); } +static inline uint8_t tconv(float x, uint8_t /*rtype*/) { return uint8_t(x); } +static inline float16 tconv(float x, float16 /*rtype*/) { + return cpu_float2half_rn(x); +} + +template T tconv(T x, T /*rtype*/) { return x; } +} diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h new file mode 100644 index 0000000..22e5a16 --- /dev/null +++ b/include/fbgemm/Utils.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include + +namespace fbgemm2 { + +/** + * @brief Helper struct to type specialize for uint8 and int8 together. + */ +template +struct is_8bit { + static constexpr bool value = + std::is_same::value || std::is_same::value; +}; + +/** + * @brief Typed enum to specify matrix operations. + */ +enum class matrix_op_t { NoTranspose, Transpose }; + +/** + * @brief Typed enum for supported instruction sets. + */ +enum class inst_set_t { anyarch, avx2, avx512 }; + +/** + * @brief Typed enum for implementation type. + * + * ref is reference and opt is optimized. + */ +enum class impl_type_t { ref, opt }; + + +/** + * @brief A struct to represent a block of a matrix. + */ +struct block_type_t { + int row_start; + int row_size; + int col_start; + int col_size; + + std::string toString() const { + std::string out = ""; + out += "row start:" + std::to_string(row_start) + ", "; + out += "row size:" + std::to_string(row_size) + ", "; + out += "col start:" + std::to_string(col_start) + ", "; + out += "col size:" + std::to_string(col_size); + return out; + } +}; + +/** + * @brief A function to compare data in two buffers for closeness/equality. 
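+ *
+ * A minimal usage sketch (m, n, and the buffers are illustrative): compare an
+ * m x n int32 result against a reference with leading dimension n, reporting
+ * at most 5 mismatches at the default tolerance.
+ *
+ * @code
+ *   int mismatches =
+ *       compare_buffers(C_ref.data(), C_test.data(), m, n, n, 5);
+ * @endcode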
+ */ +template +int compare_buffers( + const T* ref, + const T* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol = 1e-3); + +/** + * @brief Debugging helper. + */ +template +void printMatrix( + matrix_op_t trans, + const T* inp, + size_t R, + size_t C, + size_t ld, + std::string name); + +/** + * @brief Top-level routine to transpose a matrix. + * + * This calls transpose_8x8 or transpose_16x16 internally. + */ +void transpose_simd( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst); + +/** + * @brief Transpose a matrix using Intel AVX2. + * + * This is called if the code is running on a CPU with Intel AVX2 support. + */ +void transpose_8x8( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst); + +/** + * @brief Transpose a matrix using Intel AVX512. + * + * This is called if the code is running on a CPU with Intel AVX512 support. + */ +void transpose_16x16( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst); + +} // namespace fbgemm2 diff --git a/src/ExecuteKernel.cc b/src/ExecuteKernel.cc new file mode 100644 index 0000000..0e3d122 --- /dev/null +++ b/src/ExecuteKernel.cc @@ -0,0 +1,12 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "ExecuteKernel.h" +#include +#include "fbgemm/Fbgemm.h" +#include "fbgemm/Utils.h" + +namespace fbgemm2 {} // namespace fbgemm2 diff --git a/src/ExecuteKernel.h b/src/ExecuteKernel.h new file mode 100644 index 0000000..55a2581 --- /dev/null +++ b/src/ExecuteKernel.h @@ -0,0 +1,11 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include "fbgemm/Fbgemm.h" +#include "ExecuteKernelGeneric.h" +#include "ExecuteKernelU8S8.h" diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h new file mode 100644 index 0000000..e83e943 --- /dev/null +++ b/src/ExecuteKernelGeneric.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include "fbgemm/Fbgemm.h" +#include "GenerateKernel.h" + +namespace fbgemm2 { + +/** + * @brief Execute Engine for the macro-kernel and output processing. + * ExecuteKernel is a derived class of CodeGenBase. 
+ */ +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +class ExecuteKernel : public CodeGenBase< + typename packingAMatrix::inpType, + typename packingBMatrix::inpType, + cT, + typename packingBMatrix::accType> { + public: + ExecuteKernel( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + typename packingBMatrix::accType* C_buffer, + int32_t ldc, + const processOutputType& outputProcess); + void execute(int kBlock); + + private: + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& + packedA_; ///< Packed block of matrix A. + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packedB_; ///< Packed matrix B. + int32_t kBlock_; ///< Block ID in the k dimension. + cT* matC_; ///< Output for matrix C. + typename packingAMatrix::accType* + C_buffer_; ///< the accumulation buffer for matrix C. + int32_t ldc_; ///< the leading dimension of matrix C. + const processOutputType& outputProcess_; ///< output processing function for + ///< the C tile in the macro-kernel. +}; + +} // namespace fbgemm2 diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc new file mode 100644 index 0000000..5145869 --- /dev/null +++ b/src/ExecuteKernelU8S8.cc @@ -0,0 +1,354 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "ExecuteKernelU8S8.h" +#include +#include + + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double kernel_time = 0.0; +double postprocessing_time = 0.0; +#endif + +namespace fbgemm2 { + +template +ExecuteKernel< + packingAMatrix, + PackBMatrix, + cT, + processOutputType>:: + ExecuteKernel( + PackMatrix& + packA, + PackMatrix< + PackBMatrix, + int8_t, + typename packingAMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + int32_t* C_buffer, + int32_t ldc, + const processOutputType& outputProcess) + : packedA_(packA), + packedB_(packB), + kBlock_(kBlock), + matC_(matC), + C_buffer_(C_buffer), + ldc_(ldc), + outputProcess_(outputProcess) { + if (cpuinfo_has_x86_avx512f()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NCB; + } else if (cpuinfo_has_x86_avx2()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NCB; + } else { + assert(0 && "unsupported architecure"); + } + C_tile_ = new int32_t[mbSize_ * nbSize_]; +} + +template +void ExecuteKernel< + packingAMatrix, + PackBMatrix, + cT, + processOutputType>::execute(int kBlock) { + // packedA_.printPackedMatrix("packedA from kernel"); + // packedB_.printPackedMatrix("packedB from kernel"); + + int32_t bColBlocks = packedB_.blockCols(); + + int8_t* bBuf; + int8_t* bBuf_pf; + + uint8_t* aBuf = packedA_.getBuf(0); + + int32_t packed_rows_A = packedA_.numPackedRows(); + int32_t row_start_A = packedA_.packedRowStart(); + + bool lastKBlock = packedB_.isThisLastKBlock(kBlock); + bool accum = kBlock > 0; + + typename 
BaseType::jit_micro_kernel_fp fn; + + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + fn = BaseType::template getOrCreate( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else if (cpuinfo_has_x86_avx2()) { + fn = BaseType::template getOrCreate( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); +#endif + + for (int jb = 0; jb < bColBlocks; ++jb) { + + bBuf = packedB_.getBuf(jb, kBlock); + // prefetch addr of the next packed block of B matrix + bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock); + + // Reuse the first rowblock of C_buffer_ unless when C_buffer_ is same as + // matC_ (inplace output processing) + int32_t* C_buffer_row_start = C_buffer_ + + ((C_buffer_ == reinterpret_cast(matC_)) ? row_start_A * ldc_ + : 0); + int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_; + int32_t leadingDim = ldc_; + if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) { + // In case we will access memory past C_buffer_, we use C_tile_ instead. + C_buffer_start = C_tile_; + leadingDim = nbSize_; + } + + fn(aBuf, + bBuf, + bBuf_pf, + C_buffer_start, + packedA_.numPackedCols(), + leadingDim); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + kernel_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Output processing is done only once per rowblock + if (lastKBlock && jb == bColBlocks - 1) { + // When C_tile_ is used for the last column block, we need a separate + // handling for the last column block. + int32_t nSize = + C_buffer_start == C_tile_ ? 
jb * nbSize_ : packedB_.numCols(); + if (nSize) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + outputProcess_.template f( + matC_, + C_buffer_row_start, + {row_start_A, packed_rows_A, 0, nSize}, + ldc_, + ldc_); + } else if (cpuinfo_has_x86_avx2()) { + outputProcess_.template f( + matC_, + C_buffer_row_start, + {row_start_A, packed_rows_A, 0, nSize}, + ldc_, + ldc_); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } + + if (C_buffer_start == C_tile_) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + outputProcess_.template f( + matC_, + C_tile_, + {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + ldc_, + leadingDim); + } else if (cpuinfo_has_x86_avx2()) { + outputProcess_.template f( + matC_, + C_tile_, + {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + ldc_, + leadingDim); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } + } // output processing + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + postprocessing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + } // for each j block +} +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + ReQuantizeOutput>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + ReQuantizeOutput>; + +template class ExecuteKernel< + PackAWithQuantRowOffset, + PackBMatrix, + float, + ReQuantizeForFloat>; + +template class ExecuteKernel< + PackAWithQuantRowOffset, + PackBMatrix, + float, + ReQuantizeForFloat>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + float, + ReQuantizeForFloat>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + float, + ReQuantizeForFloat>; + +template class ExecuteKernel< + PackAMatrix, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAMatrix, + PackBMatrix, + uint8_t, + ReQuantizeOutput>; + +template class ExecuteKernel< + PackAMatrix, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + DoSpmdmOnInpBuffer< + ReQuantizeOutput::outType, + int32_t, + ReQuantizeOutput>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + DoSpmdmOnInpBuffer< + ReQuantizeOutput::outType, + int32_t, + ReQuantizeOutput>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + float, + DoSpmdmOnInpBuffer< + ReQuantizeForFloat::outType, + int32_t, + ReQuantizeForFloat>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + ReQuantizeOutput>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + uint8_t, + ReQuantizeOutput>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithIm2Col, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithIm2Col, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithQuantRowOffset, + PackBMatrix, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset, + PackBMatrix, + float, + ReQuantizeForFloat>; + +template class ExecuteKernel< 
+ PackAMatrix, + PackBMatrix, + int32_t, + DoNothing>; + +} // namespace fbgemm2 diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h new file mode 100644 index 0000000..0bd7fc5 --- /dev/null +++ b/src/ExecuteKernelU8S8.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include "ExecuteKernel.h" + +namespace fbgemm2 { + +/** + * @brief Execute Engine of uint 8 and int8 matrix + * multiplication for the macro-kernel and output processing. ExecuteKernel is a + * derived class of CodeGenBase. + */ +template +class ExecuteKernel< + packingAMatrix, + PackBMatrix, + cT, + processOutputType> + : public CodeGenBase< + uint8_t, + int8_t, + int32_t, + typename packingAMatrix::accType> { + public: + using BaseType = + CodeGenBase; + /** + * @brief Constructor for initializing the parameters for macro-kernel and + * output processing type. + */ + ExecuteKernel( + PackMatrix& + packA, + PackMatrix< + PackBMatrix, + int8_t, + typename packingAMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + int32_t* C_buffer, + int32_t ldc, + const processOutputType& outputProcess); + void execute(int kBlock); + + ~ExecuteKernel() { + delete[] C_tile_; + } + + private: + PackMatrix& + packedA_; ///< Packed uint8 block of matrix A. + PackMatrix< + PackBMatrix, + int8_t, + typename packingAMatrix::accType>& + packedB_; ///< Packed int8 matrix B. + int32_t kBlock_; ///< Block ID in the k dimension. + cT* matC_; ///< Output for matrix C. + int32_t* C_buffer_; ///< the accumulation buffer for matrix C. + int32_t ldc_; ///< the leading dimension of matrix C. + const processOutputType& outputProcess_; ///< output processing function for + ///< matrix C in the macro-kernel. + int32_t* C_tile_; ///< buffer for the last N block when NCB is not an exact + ///< multiple of N. + int mbSize_; ///< block size in the m dimension. + int nbSize_; ///< block size in the n dimension. +}; + +} // namespace fbgemm2 diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc new file mode 100644 index 0000000..f3bac97 --- /dev/null +++ b/src/Fbgemm.cc @@ -0,0 +1,363 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "fbgemm/Fbgemm.h" +#include +#include +#include "ExecuteKernel.h" + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double packing_time = 0.0; +double computing_time = 0.0; +double run_time = 0.0; +#endif + +using namespace fbgemm2; + +namespace fbgemm2 { + +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +void fbgemmPacked( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + cT* C, + int32_t* C_buffer, + uint32_t ldc, + const processOutputType& outProcess, + int thread_id, + int /* num_threads */) { + static_assert( + std::is_same< + typename packingAMatrix::accType, + typename packingBMatrix::accType>::value, + "Accumulation type of both matrices should be the same"); + + int MCB, KCB; + + // Run time CPU detection + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::KCB; + } else if (cpuinfo_has_x86_avx2()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::KCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + + int MDim = packA.numRows(); + int KDim = packB.numRows(); + + int mBlocks = (MDim + MCB - 1) / MCB; + int kBlocks = (KDim + KCB - 1) / KCB; + + // remainders + int _mc = MDim % MCB; + int _kc = KDim % KCB; + + int kc, mc; + + block_type_t blockA{0, 0, 0, 0}; + + // B must be prepacked + assert(packB.isPrePacked() && "B matrix must be prepacked"); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point t_very_start, + t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); + t_very_start = std::chrono::high_resolution_clock::now(); +#endif + + ExecuteKernel + exeKernelObj(packA, packB, 0, C, C_buffer, ldc, outProcess); + // ToDo: thread based work division + for (int i = 0; i < mBlocks; ++i) { + mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc; + for (int k = 0; k < kBlocks; ++k) { + kc = (k != kBlocks - 1 || _kc == 0) ? 
KCB : _kc; + // pack A matrix + blockA = {i * MCB, mc, k * KCB, kc}; + + packA.pack(blockA); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + packing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + exeKernelObj.execute(k); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + computing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + } + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = + std::chrono::duration_cast(t_end - t_very_start) + .count(); + run_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif +} + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& + packA, + PackMatrix, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& + packA, + PackMatrix, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& packA, + PackMatrix, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int32_t>& + packA, + PackMatrix, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +// 16 bit accumulation functions +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, 
int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoNothing& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix, uint8_t, int16_t>& packA, + PackMatrix, int8_t, int16_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat& outProcess, + int thread_id, + int num_threads); + +} // namespace fbgemm2 diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc new file mode 100644 index 0000000..7bbfa54 --- /dev/null +++ b/src/FbgemmFP16.cc @@ -0,0 +1,293 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "fbgemm/FbgemmFP16.h" + +#include + +#include "FbgemmFP16UKernels.h" + +using namespace std; + +namespace fbgemm2 { + +/// class that performs packing of matrix in +/// row-major or col-major format into +/// internal packed blocked-row major format + +/// Todo: make it fast with AVX2 transpose +inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) { + // for (int r = 0; r < nrow; ++r) { + // for (int c = 0; c < ncol; ++c) { + // to[r + c * nrow] = from[r * ldim + c]; + // } + // } + transpose_simd( nrow, ncol, from, ldim, to, nrow ); +} + +struct KernelInfo { + using knl_ptr = funcptr_fp16; + // optimized kernels to cover all cases + static constexpr array kernel = {{ + nullptr, gemmkernel_1x1_AVX2_fA0fB0fC0, + gemmkernel_2x1_AVX2_fA0fB0fC0, gemmkernel_3x1_AVX2_fA0fB0fC0, + gemmkernel_4x1_AVX2_fA0fB0fC0, gemmkernel_5x1_AVX2_fA0fB0fC0, + gemmkernel_6x1_AVX2_fA0fB0fC0, gemmkernel_7x1_AVX2_fA0fB0fC0, + gemmkernel_8x1_AVX2_fA0fB0fC0, gemmkernel_9x1_AVX2_fA0fB0fC0, + gemmkernel_10x1_AVX2_fA0fB0fC0, gemmkernel_11x1_AVX2_fA0fB0fC0, + gemmkernel_12x1_AVX2_fA0fB0fC0, gemmkernel_13x1_AVX2_fA0fB0fC0, + gemmkernel_14x1_AVX2_fA0fB0fC0 + }}; + + // autotuned kernel splits for various cases m = 1:mb_max + // may need re-autotuning for new uarch + static constexpr array, 2>, 121 > partition = { + { + {{ { 0, 0 }, { 0, 0 } } }, + {{ { 1, 1 }, { 0, 0 } } }, + {{ { 2, 1 }, { 0, 0 } } }, + {{ { 3, 1 }, { 0, 0 } } }, + {{ { 4, 1 }, { 0, 0 } } }, + {{ { 5, 1 }, { 0, 0 } } }, + {{ { 6, 1 }, { 0, 0 } } }, + {{ { 7, 1 }, { 0, 0 } } }, + {{ { 8, 1 }, { 0, 0 } } }, + {{ { 9, 1 }, { 0, 0 } } }, + {{ { 10, 1 }, { 0, 0 } } }, + {{ { 11, 1 }, { 0, 0 } } }, + {{ { 12, 1 }, { 0, 0 } } }, + {{ { 13, 1 }, { 0, 0 } } }, + {{ { 14, 1 }, { 0, 0 } } }, + {{ { 8, 1 }, { 7, 1 } } }, + {{ { 10, 1 }, { 6, 1 } } }, + {{ { 11, 1 }, { 6, 1 } } }, + {{ { 12, 1 }, { 6, 1 } } }, + {{ { 11, 1 }, { 8, 1 } } }, + {{ { 11, 1 }, { 9, 1 } } }, + {{ { 12, 1 }, { 9, 1 } } }, + {{ { 11, 2 }, { 0, 0 } } }, + {{ { 12, 1 }, { 11, 1 } } }, + {{ { 12, 2 }, { 0, 0 } } }, + {{ { 13, 1 }, { 12, 1 } } }, + {{ { 13, 2 }, { 0, 0 } } }, + {{ { 14, 1 }, { 13, 1 } } }, + {{ { 14, 2 }, { 0, 0 } } }, + {{ { 11, 2 }, { 7, 1 } } }, + {{ { 10, 3 }, { 0, 0 } } }, + {{ { 12, 2 }, { 7, 1 } } }, + {{ { 12, 2 }, { 8, 1 } } }, + {{ { 11, 3 }, { 0, 0 } } }, + {{ { 13, 2 }, { 8, 1 } } }, + {{ { 13, 2 }, { 9, 1 } } }, + {{ { 13, 2 }, { 10, 1 } } }, + {{ { 13, 2 }, { 11, 1 } } }, + {{ { 13, 2 }, { 12, 1 } } }, + {{ { 13, 3 }, { 0, 0 } } }, + {{ { 14, 2 }, { 12, 1 } } }, + {{ { 14, 2 }, { 13, 1 } } }, + {{ { 11, 3 }, { 9, 1 } } }, + {{ { 11, 3 }, { 10, 1 } } }, + {{ { 11, 4 }, { 0, 0 } } }, + {{ { 12, 3 }, { 9, 1 } } }, + {{ { 12, 3 }, { 10, 1 } } }, + {{ { 13, 3 }, { 8, 1 } } }, + {{ { 13, 3 }, { 9, 1 } } }, + {{ { 13, 3 }, { 10, 1 } } }, + {{ { 13, 3 }, { 11, 1 } } }, + {{ { 13, 3 }, { 12, 1 } } }, + {{ { 13, 4 }, { 0, 0 } } }, + {{ { 14, 3 }, { 11, 1 } } }, + {{ { 11, 4 }, { 10, 1 } } }, + {{ { 12, 4 }, { 7, 1 } } }, + {{ { 14, 4 }, { 0, 0 } } }, + {{ { 12, 4 }, { 9, 1 } } }, + {{ { 12, 4 }, { 10, 1 } } }, + {{ { 12, 4 }, { 11, 1 } } }, + {{ { 13, 4 }, { 8, 1 } } }, + {{ { 13, 4 }, { 9, 1 } } }, + {{ { 13, 4 }, { 10, 1 } } }, + {{ { 13, 4 }, { 11, 1 } } }, + {{ { 11, 5 }, { 9, 1 } } }, + {{ { 13, 5 }, { 0, 0 } } }, + {{ { 14, 4 }, { 10, 1 } } }, + {{ { 12, 5 }, { 7, 1 } } }, + {{ { 12, 5 }, { 8, 1 } } }, + {{ { 14, 4 }, { 13, 1 } } }, + {{ { 14, 5 }, { 0, 0 } } }, + {{ { 12, 5 }, { 11, 1 } } }, + {{ { 13, 5 }, { 7, 1 } } }, + {{ { 
11, 6 }, { 7, 1 } } }, + {{ { 13, 5 }, { 9, 1 } } }, + {{ { 13, 5 }, { 10, 1 } } }, + {{ { 13, 5 }, { 11, 1 } } }, + {{ { 13, 5 }, { 12, 1 } } }, + {{ { 13, 6 }, { 0, 0 } } }, + {{ { 12, 6 }, { 7, 1 } } }, + {{ { 12, 6 }, { 8, 1 } } }, + {{ { 12, 6 }, { 9, 1 } } }, + {{ { 12, 6 }, { 10, 1 } } }, + {{ { 12, 6 }, { 11, 1 } } }, + {{ { 12, 7 }, { 0, 0 } } }, + {{ { 13, 6 }, { 7, 1 } } }, + {{ { 13, 6 }, { 8, 1 } } }, + {{ { 13, 6 }, { 9, 1 } } }, + {{ { 13, 6 }, { 10, 1 } } }, + {{ { 13, 6 }, { 11, 1 } } }, + {{ { 13, 6 }, { 12, 1 } } }, + {{ { 13, 7 }, { 0, 0 } } }, + {{ { 12, 7 }, { 8, 1 } } }, + {{ { 12, 7 }, { 9, 1 } } }, + {{ { 14, 6 }, { 10, 1 } } }, + {{ { 12, 7 }, { 11, 1 } } }, + {{ { 13, 7 }, { 5, 1 } } }, + {{ { 13, 7 }, { 6, 1 } } }, + {{ { 13, 7 }, { 7, 1 } } }, + {{ { 13, 7 }, { 8, 1 } } }, + {{ { 13, 7 }, { 9, 1 } } }, + {{ { 13, 7 }, { 10, 1 } } }, + {{ { 13, 7 }, { 11, 1 } } }, + {{ { 13, 7 }, { 12, 1 } } }, + {{ { 12, 8 }, { 8, 1 } } }, + {{ { 12, 8 }, { 9, 1 } } }, + {{ { 12, 8 }, { 10, 1 } } }, + {{ { 12, 8 }, { 11, 1 } } }, + {{ { 12, 9 }, { 0, 0 } } }, + {{ { 11, 9 }, { 10, 1 } } }, + {{ { 13, 8 }, { 6, 1 } } }, + {{ { 13, 8 }, { 7, 1 } } }, + {{ { 13, 8 }, { 8, 1 } } }, + {{ { 13, 8 }, { 9, 1 } } }, + {{ { 13, 8 }, { 10, 1 } } }, + {{ { 13, 8 }, { 11, 1 } } }, + {{ { 12, 9 }, { 8, 1 } } }, + {{ { 13, 9 }, { 0, 0 } } }, + {{ { 12, 9 }, { 10, 1 } } }, + {{ { 12, 9 }, { 11, 1 } } }, + {{ { 12, 10 }, { 0, 0 } } } + } + }; +}; +constexpr array KernelInfo::kernel; +constexpr array, 2>, 121 > KernelInfo::partition; + +// autotuned kernel splits for various cases m = 1:mb_max +void +cblas_gemm_compute(const matrix_op_t transa, const int m, const float *A, + const PackedGemmMatrixFP16 &Bp, const float beta, + float *C) { + // ground truth + assert(cpuinfo_initialize()); + assert(cpuinfo_has_x86_fma3()); + assert(cpuinfo_has_x86_f16c()); + assert(transa == matrix_op_t::NoTranspose); + + // constants + const int n = Bp.numCols(), k = Bp.numRows(), ldc = n; + const int mb_max = 120; + constexpr int simd_width = 8; + constexpr int kernel_ncol_blocks = 1; + constexpr int kernel_ncols = kernel_ncol_blocks * simd_width; + + // private scratchpad storage + static thread_local unique_ptr > scratchpad( + new std::array()); + + GemmParams gp; + for (auto m0 = 0; m0 < m; m0 += mb_max) { + int mb = std::min(mb_max, m - m0); + assert(mb < KernelInfo::partition.size()); + for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) { + + // set up proper accumulation to avoid "Nan" problem + float beta_; + uint64_t accum; + if (k_ind == 0) { + // accumulate of beta != 0.0 + // do not!!! accumulate otherwise + beta_ = beta; + accum = (beta_ == 0.0f) ? 
0 : 1; + } else { + // always accumulate with beta_ = 1.0f + beta_ = 1.0f; + accum = 1; + } + + const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind); + + auto m1 = 0; + for (auto c = 0; c < 2; c++) { + + auto kernel_nrows = KernelInfo::partition[mb][c].first; + auto nkernel_nrows = KernelInfo::partition[mb][c].second; + + auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows; + for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { + assert(kernel_nrows * kb < scratchpad->size()); + PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); + + int nbcol = n / Bp.blockColSize(); + gp.k = kb; + gp.A = scratchpad->data(); + gp.B = &(Bp(k_ind, 0)); + gp.beta = &beta_; + gp.accum = accum; + gp.C = &C[m2 * ldc]; + gp.ldc = ldc * sizeof(C[0]); + gp.b_block_cols = nbcol; + gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]); + if ((n % Bp.blockColSize()) == 0) { + KernelInfo::kernel[kernel_nrows](&gp); + } else { + int last_blk_col = nbcol * Bp.blockColSize(); + if (nbcol) { + KernelInfo::kernel[kernel_nrows](&gp); + } + + // leftover + int rem = n - last_blk_col; + assert(rem < kernel_ncols); + int b = (rem % simd_width) ? ((rem + simd_width) / simd_width) + : (rem / simd_width); + assert(b == 1); + if ((rem % simd_width) == 0) { + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = &C[m2 * ldc + last_blk_col]; + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + } else { + // small temporary buffer + float c_tmp[16 * 24] = { 0 }; + assert((16 * 24) > kernel_nrows * kernel_ncols); + + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = c_tmp; + gp.ldc = 8 * sizeof(C[0]); + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + for (int i = 0; i < kernel_nrows; i++) { + // Todo: use assembly + for (int j = last_blk_col; j < n; j++) { + assert(i * 8 + (j - last_blk_col) < + sizeof(c_tmp) / sizeof(c_tmp[0])); + if (accum == 0) { + C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)]; + } else { + C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + + c_tmp[i * 8 + (j - last_blk_col)]; + } + } + } + } + } + } + m1 += kernel_nrows * nkernel_nrows; + } + } + } +} + + +} // namespace fbgemm diff --git a/src/FbgemmFP16UKernels.cc b/src/FbgemmFP16UKernels.cc new file mode 100644 index 0000000..ec1b297 --- /dev/null +++ b/src/FbgemmFP16UKernels.cc @@ -0,0 +1,2203 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "FbgemmFP16UKernels.h" + +namespace fbgemm2 { + +void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm1,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm1\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm1,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm0,ymm14,ymm1\t\n" +"add r11, 32\t\n" +"add r9,8\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm2\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm2\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm0,ymm14,ymm2\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm1,ymm14,ymm2\t\n" +"add r11, 32\t\n" +"add r9,16\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" 
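+// step the C pointer by ldc (r13, held in bytes) to the next output row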
+"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm3\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm0,ymm14,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm1,ymm14,ymm3\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm2,ymm14,ymm3\t\n" +"add r9,24\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A 
+"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm4\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm0,ymm14,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm1,ymm14,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm2,ymm14,ymm4\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm3,ymm14,ymm4\t\n" +"add r9,32\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" + 
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm5\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm0,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm1,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm2,ymm14,ymm5\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm3,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm4,ymm14,ymm5\t\n" +"add r9,40\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss 
ymm6,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm6\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm0,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm1,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm2,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm3,ymm14,ymm6\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm4,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm5,ymm14,ymm6\t\n" +"add r9,48\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps 
ymm6,ymm6,ymm6\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm7\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm0,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm1,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm2,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm3,ymm14,ymm7\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm4,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm5,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm6,ymm14,ymm7\t\n" +"add r9,56\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, 
[r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm8\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm0,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm1,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm2,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm3,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm4,ymm14,ymm8\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm5,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm6,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm7,ymm14,ymm8\t\n" +"add r9,64\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" 
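+// rows 6-7: same read-modify-write, acc += beta * C, then store and step by ldc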
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm9\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm0,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm1,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm2,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm3,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm4,ymm14,ymm9\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm5,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm6,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm7,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm8,ymm14,ymm9\t\n" +"add r9,72\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add 
r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm10\t\n" 
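+// rows 8-9 of this k-step reuse ymm10 as the A-broadcast scratch register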
+"vbroadcastss ymm10,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm10\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm0,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm1,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm2,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm3,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm4,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm5,ymm14,ymm10\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm6,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm7,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm8,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm9,ymm14,ymm10\t\n" +"add r9,80\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) 
+"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm11\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm0,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm1,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm2,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm3,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm4,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm5,ymm14,ymm11\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm6,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm7,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm8,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm9,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm10,ymm14,ymm11\t\n" +"add r9,88\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], 
ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" 
+"vfmadd231ps ymm5,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm12\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm0,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm1,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm2,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm3,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm4,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm5,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm6,ymm14,ymm12\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm7,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm8,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm9,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+88]\t\n" +"vfmadd231ps ymm10,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+92]\t\n" +"vfmadd231ps ymm11,ymm14,ymm12\t\n" +"add r9,96\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" 
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" +"vxorps ymm12,ymm12,ymm12\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm12,ymm15,ymm13\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm0,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm1,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm2,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm3,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm4,ymm14,ymm13\t\n" 
+"vbroadcastss ymm13,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm5,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm6,ymm14,ymm13\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm7,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm8,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+88]\t\n" +"vfmadd231ps ymm9,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+92]\t\n" +"vfmadd231ps ymm10,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+96]\t\n" +"vfmadd231ps ymm11,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+100]\t\n" +"vfmadd231ps ymm12,ymm14,ymm13\t\n" +"add r9,104\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void 
__attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" +"vxorps ymm12,ymm12,ymm12\t\n" +"vxorps ymm13,ymm13,ymm13\t\n" + +"mov r11, 0\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm12,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm13,ymm15,ymm14\t\n" +"add r9,56\t\n" +"add r11, 16\t\n" +"inc r14\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 
+ 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} + +} // namespace fbgemm2 diff --git a/src/FbgemmFP16UKernels.h b/src/FbgemmFP16UKernels.h new file mode 100644 index 0000000..bf7f247 --- /dev/null +++ b/src/FbgemmFP16UKernels.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+#ifndef FBGEMM_UKERNELS
+#define FBGEMM_UKERNELS
+#include <cstdint>
+#include <tuple>
+#include <vector>
+#include "fbgemm/Types.h"
+
+namespace fbgemm2 {
+
+using fp16 = float16;
+using fp32 = float;
+struct GemmParams {uint64_t k; float *A; const fp16 *B;
+float *beta; uint64_t accum; float *C; uint64_t ldc;
+uint64_t b_block_cols; uint64_t b_block_size;};
+void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp);
+typedef void (* funcptr_fp16) (GemmParams *gp);
+;
+
+} // namespace fbgemm2
+
+#endif
diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc
new file mode 100644
index 0000000..54e2272
--- /dev/null
+++ b/src/FbgemmI8Depthwise.cc
@@ -0,0 +1,1953 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmI8Depthwise.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <tuple>
+#include <vector>
+
+#include <immintrin.h>
+
+using namespace std;
+
+namespace fbgemm2
+{
+
+static array<array<int, 8>, 8> masks = {{
+  {  0,  0,  0,  0,  0,  0,  0,  0, },
+  { -1,  0,  0,  0,  0,  0,  0,  0, },
+  { -1, -1,  0,  0,  0,  0,  0,  0, },
+  { -1, -1, -1,  0,  0,  0,  0,  0, },
+  { -1, -1, -1, -1,  0,  0,  0,  0, },
+  { -1, -1, -1, -1, -1,  0,  0,  0, },
+  { -1, -1, -1, -1, -1, -1,  0,  0, },
+  { -1, -1, -1, -1, -1, -1, -1,  0, },
+}};
+
+template <int KERNEL_PROD>
+PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
+    int K, const int8_t *smat)
+    : K_(K) {
+  // Transpose the input matrix to make packing faster.
+  vector<int8_t> smat_transposed(K * KERNEL_PROD);
+  for (int i = 0; i < KERNEL_PROD; ++i) {
+    for (int j = 0; j < K; ++j) {
+      smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
+    }
+  }
+
+  // Allocate packed arrays
+  constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
+  pmat_ = static_cast<int8_t *>(aligned_alloc(
+      64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+
+  // Pack input matrix
+  // The layout is optimized to use vpmaddubsw efficiently (see
+  // madd_epi16x4_packed function).
+  // For a group of 32 channels, we have 10 32B SIMD registers.
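+  // (With KERNEL_PROD = 9, filter coefficients 0-3 and 4-7 each occupy four
+  // registers for the 32 channels, and the remainder coefficient 8 occupies
+  // two more, i.e. 4 + 4 + 2 = 10 registers; see the REMAINDER cases below.)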
+ // Denote ith channel jth filter as (i, j) + // 0th SIMD register: + // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) + // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) + // 1st SIMD register: + // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) + // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) + // 2nd SIMD register: + // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) + // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) + // 3rd SIMD register: + // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) + // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) + // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter + // coefficients + // ... + // + // REMAINDER + // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9 + // 8th SIMD register: + // (0, 8), zero, ..., (7, 8), zero + // (16, 8), zero, ..., (23, 8), zero + // 9th SIMD register: + // (8, 8), zero, ..., (15, 8), zero + // (24, 8), zero, ..., (31, 8), zero + // We use madd_epi16_packed for this case + // + // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10 + // 8th SIMD register: + // (0, 8), (0, 9), ..., (7, 8), (7, 9) + // (16, 8), (16, 9), ..., (23, 8), (23, 9) + // 9th SIMD register: + // (8, 8), (8, 9), ..., (15, 8), (15, 9) + // (24, 8), (24, 9), ..., (31, 8), (31, 9) + // + // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11 + // 8th SIMD register: + // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero + // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero + // 9th SIMD register: + // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero + // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero + // 10th SIMD register: + // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero + // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero + // 11th SIMD register: + // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero + // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero + for (int k1 = 0; k1 < K; k1 += 32) { + array<__m256i, KERNEL_PROD> b_v; + int remainder = K - k1; + if (remainder < 32) { + __m256i mask_v = _mm256_loadu_si256( + reinterpret_cast(masks[remainder / 4].data())); + for (int i = 0; i < KERNEL_PROD; ++i) { + b_v[i] = _mm256_maskload_epi32( + reinterpret_cast(smat_transposed.data() + i * K + k1), + mask_v); + } + } else { + for (int i = 0; i < KERNEL_PROD; ++i) { + b_v[i] = _mm256_lddqu_si256(reinterpret_cast( + smat_transposed.data() + i * K + k1)); + } + } + + // Interleave 2 SIMD registers + array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi16; + __m256i zero_v = _mm256_setzero_si256(); + for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) { + if (2 * i + 1 >= KERNEL_PROD) { + b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], zero_v); + } else { + b_interleaved_epi16[2 * i] = + _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); + } + } + + // Interleave 4 SIMD registers + array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi32; + for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) { + b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( + 
b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + } + for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) { + b_interleaved_epi32[i] = b_interleaved_epi16[i]; + } + + for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>( + &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]), + b_interleaved_epi32[i]); + } + } +} + +template +PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() +{ + free(pmat_); +} + +template class PackedDepthWiseConvMatrix<3 * 3>; +template class PackedDepthWiseConvMatrix<3 * 3 * 3>; + +// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template +static inline __attribute__((always_inline)) +void madd_epi16x4_packed( + __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v, + const __m256i* b, + __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 + a2 * b2 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template +static inline __attribute__((always_inline)) +void madd_epi16x3_packed( + __m256i a0_v, __m256i a1_v, __m256i a2_v, + const __m256i* b, + __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* 
c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template +static inline __attribute__((always_inline)) void +madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v, + __m256i *c1_v, __m256i *c2_v, __m256i *c3_v, + __m256i *a_sum = nullptr) { + __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// c = a0 * b0 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template +static inline __attribute__((always_inline)) void +madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v, + __m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); + __m256i a_hi_v = 
_mm256_unpackhi_epi8(a_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// K is the number of accumulations we're doing +template +static inline __attribute__((always_inline)) void +inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C, + int remainder, __m256i *a_sum = nullptr) { + array<__m256i, 4> c, c_temp; + array<__m256i, 2> a_sum_temp{}; + + int k = 0; + if (K >= 4) { + madd_epi16x4_packed(a_v[0], a_v[1], a_v[2], a_v[3], Bp, + &c[0], &c[1], &c[2], &c[3], a_sum_temp.data()); + + for (k = 4; k < K / 4 * 4; k += 4) { + madd_epi16x4_packed(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3], + Bp + k, &c_temp[0], &c_temp[1], &c_temp[2], + &c_temp[3], a_sum_temp.data()); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + } else { + c[0] = _mm256_setzero_si256(); + c[1] = _mm256_setzero_si256(); + c[2] = _mm256_setzero_si256(); + c[3] = _mm256_setzero_si256(); + } + + if (K - k == 3) { + madd_epi16x3_packed(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k, + &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3], + a_sum_temp.data()); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); + c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); + c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); + c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); + + if (K - k == 0 || K - k == 3) { + c[0] = c_temp[0]; + c[1] = c_temp[1]; + c[2] = c_temp[2]; + c[3] = c_temp[3]; + } else { + if (K - k == 1) { + madd_epi16_packed(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], + a_sum_temp.data()); + } else if (K - k == 2) { + madd_epi16x2_packed(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], + &c[2], &c[3], a_sum_temp.data()); + } + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + if (REMAINDER) { + for (int r = 0; r < remainder / 8; ++r) { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + r * 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)), + c[r])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]); + } + } + } else { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C), + _mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)), + c[0])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 16), + _mm256_add_epi32( + 
_mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 24), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]); + } + } + + if (SUM_A) { + a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); + a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); + a_sum[2] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); + a_sum[3] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); + } +} + +template +static inline __attribute__((always_inline)) +void inner_prod_3x3_packed_(const __m256i* a_v, + const __m256i* Bp, + int32_t* C, + int remainder, + __m256i* a_sum = nullptr) { + return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, + a_sum); +} + +// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different +// row_offsets for each row because of depth-wise convolution +template +static inline __attribute__((always_inline)) void requantize_( + int32_t A_zero_point, + const float* C_multiplier, + int32_t C_zero_point, + const int32_t* C_int32, + uint8_t* C_uint8, + int n, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias) { + __m256 multiplier_v = _mm256_setzero_ps(); + if (!PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_set1_ps(*C_multiplier); + } + + __m256i min_v = _mm256_set1_epi8(numeric_limits::min()); + __m256i max_v = _mm256_set1_epi8(numeric_limits::max()); + + __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); + __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); + __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); + + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + int j = 0; + for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast(C_int32 + j)); + __m256i y_v = _mm256_loadu_si256( + reinterpret_cast(C_int32 + j + VLEN)); + __m256i z_v = _mm256_loadu_si256( + reinterpret_cast(C_int32 + j + 2 * VLEN)); + __m256i w_v = _mm256_loadu_si256( + reinterpret_cast(C_int32 + j + 3 * VLEN)); + + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j))); + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); + x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j + VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + VLEN)); + y_v = _mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast( + col_offsets + j + 2 * VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast( + col_offsets + j + 3 * VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + 3 * VLEN)); + w_v = 
_mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v); + + if (HAS_BIAS) { // static if + x_v = _mm256_add_epi32( + x_v, _mm256_loadu_si256(reinterpret_cast(bias + j))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast(bias + j + VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast(bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast(bias + j + 3 * VLEN))); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN); + } + __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); + } + __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); + } + __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + __m256i xy_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); + __m256i zw_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + __m256i xyzw_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(xyzw_packed_v, max_v)); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); + } // j loop vectorized and unrolled 4x + + for ( ; j < n / VLEN * VLEN; j += VLEN) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast(C_int32 + j)); + + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j))); + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); + x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + + if (HAS_BIAS) { // static if + x_v = _mm256_add_epi32( + x_v, _mm256_loadu_si256(reinterpret_cast(bias + j))); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); + + x_clamped_v = + _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + _mm_storel_epi64( + reinterpret_cast<__m128i*>(C_uint8 + j), + _mm256_castsi256_si128(x_clamped_v)); + } // j loop vectorized + + for ( ; j < n; ++j) { + int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j]; + if (HAS_BIAS) { // static if + raw += bias[j]; + } + + float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? 
j : 0]; + long rounded = lrintf(ab) + C_zero_point; + + C_uint8[j] = std::max( + FUSE_RELU ? static_cast(C_zero_point) : 0l, + std::min(255l, rounded)); + } +} + +template +static inline __attribute__((always_inline)) void +requantize_(int32_t A_zero_point, float C_multiplier, + int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8, + int n, const int32_t *row_offsets, const int32_t *col_offsets, + const int32_t *bias) { + requantize_( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8, + n, + row_offsets, + col_offsets, + bias); +} + +template +static inline __attribute__((always_inline)) void +requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier, + int32_t C_zero_point, const int32_t *C_int32, + uint8_t *C_uint8, int n, const int32_t *row_offsets, + const int32_t *col_offsets, const int32_t *bias) { + requantize_( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8, + n, + row_offsets, + col_offsets, + bias); +} + +template +static inline __attribute__((always_inline)) __m256i +load_a(const uint8_t* A, __m256i mask_v) { + if (REMAINDER) { + return _mm256_maskload_epi32(reinterpret_cast(A), mask_v); + } else { + return _mm256_lddqu_si256(reinterpret_cast(A)); + } +} + +template +static inline __attribute__((always_inline)) void +inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in, + const uint8_t *A, int32_t A_zero_point, const int8_t *Bp, + const int32_t *B_zero_point, int32_t *C, int remainder, + int32_t *row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast(masks[remainder / 4].data())); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
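+  // Taps that fall outside the input (the padded border) keep the default
+  // value A_zero_point, i.e. the quantized representation of 0, so they do
+  // not contribute to the result once the zero-point corrections are applied.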
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + array<__m256i, 9> a_v = { + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + }; + + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a(A + (0 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a(A + (0 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a(A + (0 * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a(A + (1 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a(A + (1 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a(A + (1 * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a(A + (2 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a(A + (2 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[8] = load_a(A + (2 * W + 2) * K, mask_v); + } + } + + array<__m256i, 4> a_sum; + inner_prod_3x3_packed_( + a_v.data(), reinterpret_cast(Bp), C, remainder, + a_sum.data()); + if (SUM_A) { + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template +static inline __attribute__((always_inline)) void +inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in, + int w_in, const uint8_t *A, int32_t A_zero_point, + const int8_t *Bp, const int32_t *B_zero_point, + int32_t *C, int remainder, int32_t *row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast(masks[remainder / 4].data())); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
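+  // As in the 2D kernel, out-of-bounds taps default to A_zero_point (the
+  // quantized representation of 0). The 27 taps of the 3x3x3 filter are then
+  // consumed in four groups of 8, 8, 8 and 3, matching the inner_prod_packed_
+  // calls on Bp, Bp + 8, Bp + 16 and Bp + 24 below.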
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + array<__m256i, 8> a_v; + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a(A + ((0 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a(A + ((0 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a(A + ((0 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a(A + ((0 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a(A + ((0 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a(A + ((0 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a(A + ((0 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a(A + ((0 * H + 2) * W + 1) * K, mask_v); + } + } + } + + array<__m256i, 4> a_sum; + inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(), + reinterpret_cast(Bp), + C, remainder, a_sum.data()); + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[0] = load_a(A + ((0 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[1] = load_a(A + ((1 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[2] = load_a(A + ((1 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[3] = load_a(A + ((1 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[4] = load_a(A + ((1 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[5] = load_a(A + ((1 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[6] = load_a(A + ((1 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[7] = load_a(A + ((1 * H + 2) * W + 0) * K, mask_v); + } + } + } + + array<__m256i, 4> a_sum_temp; + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast(Bp) + 8, C, remainder, + a_sum_temp.data()); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = 
A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[0] = load_a(A + ((1 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[1] = load_a(A + ((1 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[2] = load_a(A + ((2 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[3] = load_a(A + ((2 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[4] = load_a(A + ((2 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[5] = load_a(A + ((2 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[6] = load_a(A + ((2 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[7] = load_a(A + ((2 * H + 1) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast(Bp) + 16, C, remainder, + a_sum_temp.data()); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a(A + ((2 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a(A + ((2 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a(A + ((2 * H + 2) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast(Bp) + 24, C, remainder, + a_sum_temp.data()); + + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template +static inline __attribute__((always_inline)) +void depthwise_3x3_kernel_(int H, int W, int K, int h, int w, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t *col_offsets, + const int32_t *bias) +{ + constexpr int S = 3; + constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3_packed_( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, &B_zero_point, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3_packed_( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, &B_zero_point, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_ + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + (h * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +template +static inline __attribute__((always_inline)) +void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t *col_offsets, + const int32_t *bias) +{ + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_( + T, H, W, K, t_in, h_in, w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point, + Bp + k * 28, &B_zero_point, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_( + T, H, W, K, t_in, h_in, w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point, + Bp + k * 28, &B_zero_point, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_ + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +template +static inline __attribute__((always_inline)) void +depthwise_3x3_per_channel_quantization_kernel_( + int H, int W, int K, int h, int w, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + const int32_t *B_zero_point, const int8_t *Bp, + const float *C_multiplier, int32_t C_zero_point, + int32_t *C_int32, uint8_t *C_uint8, + int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) { + constexpr int S = 3; + constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int h_in = -PAD_T + h * 
stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3_packed_( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, B_zero_point + k, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3_packed_( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, B_zero_point + k, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_per_channel_ + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + (h * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +static pair closest_factors_(int n) { + int a = (int)std::sqrt(n); + while (n % a != 0) { + a--; + } + return { a, n / a }; // a <= n / a +} + +// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0 +// This implemntation should be general enough to handle not just 3x3 but other +// filter shapes by parameterizing with R and S but restricting it to just 3x3 +// for now. +template +static inline __attribute__((always_inline)) +void depthwise_3x3_pad_1_(int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + int32_t B_zero_point, const Packed3x3ConvMatrix &B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + const int32_t *col_offsets, const int32_t *bias, + int thread_id, int num_threads) { + assert(K % 8 == 0); + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + if (h_begin == 0) { + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { + if (w_begin == 0) { + w = 0; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + if (h_end == H_OUT) { + h = H_OUT - 1; + w = 0; + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + } // for each n +}; + +template +static inline __attribute__((always_inline)) +void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + int32_t B_zero_point, + const Packed3x3x3ConvMatrix &B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + const int32_t *col_offsets, const int32_t *bias, + int thread_id, int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + C_temp = + FUSE_RESCALE + ? 
C_int32 + : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3x3_kernel_( + T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point, + A_base, B_zero_point, Bp, C_multiplier, + C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets, + bias); + } // w + } // h + } // t + } // for each n +}; + +template +static inline __attribute__((always_inline)) void +depthwise_3x3_per_channel_quantization_pad_1_( + int N, int H, int W, int K, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point, + const Packed3x3ConvMatrix &B, const float *C_multiplier, + int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8, + const int32_t *col_offsets, const int32_t *bias, int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + if (h_begin == 0) { + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { + if (w_begin == 0) { + w = 0; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + if (h_end == H_OUT) { + h = H_OUT - 1; + w = 0; + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + } // for each n +}; + +// assumption: W > 3 and H > 3 +void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, const uint8_t* A, + const Packed3x3ConvMatrix& B, + int32_t* C, + int thread_id, + int num_threads) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } +} + +void depthwise_3x3_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads, + bool fuse_relu) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (fuse_relu) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } + } else { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + 
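+      // 14x14 / stride-2 shape gets its own instantiation of
+      // depthwise_3x3_pad_1_ so this common case can be specialized at
+      // compile time (same pattern as the fuse_relu dispatch above).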
depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } + } +} + +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, const uint8_t* A, + const Packed3x3x3ConvMatrix& B, + int32_t* C, + int thread_id, + int num_threads) { + depthwise_3x3x3_pad_1_( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); +} + +static void depthwise_3x3x3_pad_1_( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + depthwise_3x3x3_pad_1_( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); +} + +static void depthwise_3x3x3_pad_1_relu_fused_( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + depthwise_3x3x3_pad_1_( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); +} + +void depthwise_3x3x3_pad_1( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, int num_threads) { + // If we inline the following two functions, I see stack overflow. 
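+  // The two wrappers it refers to (depthwise_3x3x3_pad_1_ and
+  // depthwise_3x3x3_pad_1_relu_fused_) each hold a large stack-allocated
+  // C_int32_temp scratch array, which is presumably why inlining them here
+  // overflows the stack; the relu / no-relu choice is baked into the
+  // template instantiation each wrapper makes, so this function only
+  // dispatches between them.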
+ if (fuse_relu) { + depthwise_3x3x3_pad_1_relu_fused_( + N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A, + B_zero_point, B, C_multiplier, C_zero_point, C, + col_offsets, bias, thread_id, num_threads); + } else { + depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w, + A_zero_point, A, B_zero_point, B, C_multiplier, + C_zero_point, C, col_offsets, bias, + thread_id, num_threads); + } +} + +void depthwise_3x3_per_channel_quantization_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp, + const float *C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } +} + +} // namespace fbgemm2 diff --git a/src/FbgemmI8Depthwise.h b/src/FbgemmI8Depthwise.h new file mode 100644 index 0000000..bc62c84 --- /dev/null +++ b/src/FbgemmI8Depthwise.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +namespace fbgemm2 +{ + +// KERNEL_PROD is the product of all kernels. +// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3. 
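+//
+// A filter is packed once into a PackedDepthWiseConvMatrix and then reused
+// across calls. Illustrative sketch (the caller-side buffer names below are
+// placeholders, not part of this header):
+//
+//   Packed3x3ConvMatrix Bp(K, B_int8);   // K depthwise 3x3 filters, K % 8 == 0
+//   depthwise_3x3_pad_1(N, H, W, K, /*stride_h=*/1, /*stride_w=*/1,
+//                       A_zero_point, A_uint8, B_zero_point, Bp,
+//                       C_multiplier, C_zero_point, C_uint8,
+//                       col_offsets, bias);  // fused-requantization overload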
+template +class PackedDepthWiseConvMatrix +{ + public: + // smat in RSG layout + PackedDepthWiseConvMatrix(int K, const std::int8_t *smat); + virtual ~PackedDepthWiseConvMatrix(); + + const std::int8_t* PackedMat() const { + return pmat_; + } + + private: + int K_; + std::int8_t* pmat_; +}; // Packed3x3ConvMatrix + +using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>; +using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * @params A The input image in NHWK layout + * @params Bp The pre-packed filter + */ +void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization. + */ +void depthwise_3x3_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp, + float C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + int thread_id = 0, int num_threads = 1, bool fuse_relu = false); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization and uses per-channel quantization. + */ +void depthwise_3x3_per_channel_quantization_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const std::int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp, + const float *C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + int thread_id = 0, int num_threads = 1); + +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, int num_threads = 1); + +void depthwise_3x3x3_pad_1( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp, + float C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); + +} // namespace fbgemm2 diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc new file mode 100644 index 0000000..723a467 --- /dev/null +++ b/src/FbgemmI8Spmdm.cc @@ -0,0 +1,508 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "fbgemm/FbgemmI8Spmdm.h" + +#include +#include +#include +#include +#include + +#include + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double spmdm_initial_time = 0.0; +double spmdm_transpose_uint8_time = 0.0; +double spmdm_transpose_32xN_time = 0.0; +double spmdm_compute_time = 0.0; +double spmdm_transpose_Nx32_time = 0.0; +double spmdm_run_time = 0.0; +#endif + +using namespace std; + +namespace fbgemm2 { + +CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols) + : num_rows_(num_of_rows), + colptr_(num_of_cols + 1), + hyper_sparse_(false), + old_nnz_(-1) {} + +double CompressedSparseColumn::Density() const { + return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols()); +} + +bool CompressedSparseColumn::IsHyperSparse() const { + if (NumOfNonZeros() != old_nnz_) { + old_nnz_ = NumOfNonZeros(); + // The number of non-zero per row is very small. + hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08; + } + + return hyper_sparse_; +} + +static void transpose_8rows( + int N, + const uint8_t* src, + int ld_src, + uint8_t* dst, + int ld_dst) { + constexpr int M = 8; + int j; + // vectorized loop + for (j = 0; j < N / 32 * 32; j += 32) { + // a : a0 a1 ... a31 + // b : b0 b1 ... b31 + // c : c0 c1 ... c31 + // d : d0 d1 ... d31 + __m256i a = _mm256_lddqu_si256( + reinterpret_cast(src + j + 0 * ld_src)); + __m256i b = _mm256_lddqu_si256( + reinterpret_cast(src + j + 1 * ld_src)); + __m256i c = _mm256_lddqu_si256( + reinterpret_cast(src + j + 2 * ld_src)); + __m256i d = _mm256_lddqu_si256( + reinterpret_cast(src + j + 3 * ld_src)); + __m256i e = _mm256_lddqu_si256( + reinterpret_cast(src + j + 4 * ld_src)); + __m256i f = _mm256_lddqu_si256( + reinterpret_cast(src + j + 5 * ld_src)); + __m256i g = _mm256_lddqu_si256( + reinterpret_cast(src + j + 6 * ld_src)); + __m256i h = _mm256_lddqu_si256( + reinterpret_cast(src + j + 7 * ld_src)); + + // even-odd interleaving + // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23 + // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31 + // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23 + // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31 + __m256i ab_lo = _mm256_unpacklo_epi8(a, b); + __m256i ab_hi = _mm256_unpackhi_epi8(a, b); + __m256i cd_lo = _mm256_unpacklo_epi8(c, d); + __m256i cd_hi = _mm256_unpackhi_epi8(c, d); + __m256i ef_lo = _mm256_unpacklo_epi8(e, f); + __m256i ef_hi = _mm256_unpackhi_epi8(e, f); + __m256i gh_lo = _mm256_unpacklo_epi8(g, h); + __m256i gh_hi = _mm256_unpackhi_epi8(g, h); + + // 4-row interleaving but permuted at 128-bit granularity + // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19 + // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23 + // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27 + // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... 
a-d31 + __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo); + __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo); + __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi); + __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi); + __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo); + __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo); + __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi); + __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi); + + // 8-row interleaving + __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0); + __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0); + __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1); + __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1); + __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2); + __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2); + __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3); + __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3); + + // Storing with 128-bit lanes are permuted so that everything is in order + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst), + _mm256_castsi256_si128(y0)); + *reinterpret_cast(dst + (j + 1) * ld_dst) = + _mm256_extract_epi64(y0, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst), + _mm256_castsi256_si128(y1)); + *reinterpret_cast(dst + (j + 3) * ld_dst) = + _mm256_extract_epi64(y1, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst), + _mm256_castsi256_si128(y2)); + *reinterpret_cast(dst + (j + 5) * ld_dst) = + _mm256_extract_epi64(y2, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst), + _mm256_castsi256_si128(y3)); + *reinterpret_cast(dst + (j + 7) * ld_dst) = + _mm256_extract_epi64(y3, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst), + _mm256_castsi256_si128(y4)); + *reinterpret_cast(dst + (j + 9) * ld_dst) = + _mm256_extract_epi64(y4, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst), + _mm256_castsi256_si128(y5)); + *reinterpret_cast(dst + (j + 11) * ld_dst) = + _mm256_extract_epi64(y5, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst), + _mm256_castsi256_si128(y6)); + *reinterpret_cast(dst + (j + 13) * ld_dst) = + _mm256_extract_epi64(y6, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst), + _mm256_castsi256_si128(y7)); + *reinterpret_cast(dst + (j + 15) * ld_dst) = + _mm256_extract_epi64(y7, 1); + *reinterpret_cast(dst + (j + 16) * ld_dst) = + _mm256_extract_epi64(y0, 2); + *reinterpret_cast(dst + (j + 17) * ld_dst) = + _mm256_extract_epi64(y0, 3); + *reinterpret_cast(dst + (j + 18) * ld_dst) = + _mm256_extract_epi64(y1, 2); + *reinterpret_cast(dst + (j + 19) * ld_dst) = + _mm256_extract_epi64(y1, 3); + *reinterpret_cast(dst + (j + 20) * ld_dst) = + _mm256_extract_epi64(y2, 2); + *reinterpret_cast(dst + (j + 21) * ld_dst) = + _mm256_extract_epi64(y2, 3); + *reinterpret_cast(dst + (j + 22) * ld_dst) = + _mm256_extract_epi64(y3, 2); + *reinterpret_cast(dst + (j + 23) * ld_dst) = + _mm256_extract_epi64(y3, 3); + *reinterpret_cast(dst + (j + 24) * ld_dst) = + _mm256_extract_epi64(y4, 2); + *reinterpret_cast(dst + (j + 25) * ld_dst) = + _mm256_extract_epi64(y4, 3); + *reinterpret_cast(dst + (j + 26) * ld_dst) = + _mm256_extract_epi64(y5, 2); + *reinterpret_cast(dst + (j + 27) * ld_dst) = + _mm256_extract_epi64(y5, 3); + *reinterpret_cast(dst + (j + 28) * ld_dst) = + _mm256_extract_epi64(y6, 2); + *reinterpret_cast(dst + (j + 29) * ld_dst) = + _mm256_extract_epi64(y6, 3); + *reinterpret_cast(dst + (j + 30) 
* ld_dst) = + _mm256_extract_epi64(y7, 2); + *reinterpret_cast(dst + (j + 31) * ld_dst) = + _mm256_extract_epi64(y7, 3); + } + + // scalar loop for remainder + for (; j < N; ++j) { + for (int i = 0; i < M; ++i) { + dst[j * ld_dst + i] = src[j + i * ld_src]; + } + } +} + +// TODO: fallback when AVX2 is not available +void CompressedSparseColumn::SpMDM( + const block_type_t& block, + const uint8_t* A, + int lda, + bool accumulation, + int32_t* C, + int ldc) const { + int K = NumOfRows(); + int N = block.col_size; + + if (K == 0 || N == 0) { + return; + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point t_very_start, + t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); + t_very_start = std::chrono::high_resolution_clock::now(); +#endif + + uint8_t A_buffer[K * 32] __attribute__((aligned(64))); + int32_t C_buffer[N * 32] __attribute__((aligned(64))); + + // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for + // each non-zero in B, we'd need to access the corresponding column in A. + // This results in strided access, which we want to avoid. + // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T + + if (IsHyperSparse()) { + // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. + // If NNZ/K is small, it's not worth doing transpose so we just use this + // scalar loop. + if (!accumulation) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; + ++j) { + C[(i - block.row_start) * ldc + j - block.col_start] = 0; + } + } + } + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = block.row_start; i < block.row_start + block.row_size; + ++i) { + C[(i - block.row_start) * ldc + j - block.col_start] += + A[i * lda + row] * w; + } + } + } // for each column of B + return; + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + spmdm_initial_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Take 32 rows at a time + int i_end = block.row_start + block.row_size; + for (int i1 = block.row_start; i1 < i_end; i1 += 32) { + // Transpose 32 x K submatrix of A + if (i_end - i1 < 32) { + uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64))); + for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) { + transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); + } + + for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) { + memcpy( + A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t)); + } + memset( + A_temp_buffer + (i_end - i1) * K, + 0, + (32 - (i_end - i1)) * K * sizeof(uint8_t)); + for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { + transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); + } + } else { + for (int i2 = 0; i2 < 32; i2 += 8) { + transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); + } + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + spmdm_transpose_uint8_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + if (accumulation) { + // Transpose 32 x N submatrix of C to fill N x 32 C_buffer + transpose_simd( + std::min(32, i_end - 
i1), + N, + reinterpret_cast(C + (i1 - block.row_start) * ldc), + ldc, + reinterpret_cast(C_buffer), + 32); + } else { + memset(C_buffer, 0, N * 32 * sizeof(int32_t)); + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + spmdm_transpose_32xN_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + for (int j = 0; j < block.col_size; ++j) { + int j_start = j + block.col_start; + int k = colptr_[j_start]; + int k_end_aligned = + colptr_[j_start] + (colptr_[j_start + 1] - colptr_[j_start]) / 4 * 4; + + for (; k < k_end_aligned; k += 4) { + __m256i w = + _mm256_set1_epi32(*(reinterpret_cast(&values_[k]))); + array<__m256i, 4> a; + a[0] = _mm256_load_si256( + reinterpret_cast(&A_buffer[rowidx_[k + 0] * 32])); + a[1] = _mm256_load_si256( + reinterpret_cast(&A_buffer[rowidx_[k + 1] * 32])); + a[2] = _mm256_load_si256( + reinterpret_cast(&A_buffer[rowidx_[k + 2] * 32])); + a[3] = _mm256_load_si256( + reinterpret_cast(&A_buffer[rowidx_[k + 3] * 32])); + + __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); + __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); + __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); + __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); + + a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); + a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); + a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); + a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); + + array<__m256i, 4> ab; + ab[0] = _mm256_maddubs_epi16(a[0], w); + ab[1] = _mm256_maddubs_epi16(a[1], w); + ab[2] = _mm256_maddubs_epi16(a[2], w); + ab[3] = _mm256_maddubs_epi16(a[3], w); + + __m256i one = _mm256_set1_epi16(1); + ab[0] = _mm256_madd_epi16(ab[0], one); + ab[1] = _mm256_madd_epi16(ab[1], one); + ab[2] = _mm256_madd_epi16(ab[2], one); + ab[3] = _mm256_madd_epi16(ab[3], one); + + array<__m256i, 4> t; + t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); + t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); + t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); + t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); + + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 0 * 8])), + t[0])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 1 * 8])), + t[1])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 2 * 8])), + t[2])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 3 * 8])), + t[3])); + } + + int remainder = colptr_[j_start + 1] - k; + assert(remainder < 4); + if (remainder > 0) { + int32_t temp_w = 0; + for (int r = 0; r < remainder; ++r) { + (reinterpret_cast(&temp_w))[r] = values_[k + r]; + } + __m256i w = _mm256_set1_epi32(temp_w); + array<__m256i, 4> a; + a[0] = _mm256_load_si256( + reinterpret_cast(&A_buffer[rowidx_[k + 0] * 32])); + a[1] = remainder > 1 + ? _mm256_load_si256(reinterpret_cast( + &A_buffer[rowidx_[k + 1] * 32])) + : _mm256_setzero_si256(); + a[2] = remainder > 2 + ? 
_mm256_load_si256(reinterpret_cast( + &A_buffer[rowidx_[k + 2] * 32])) + : _mm256_setzero_si256(); + a[3] = _mm256_setzero_si256(); + + __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); + __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); + __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); + __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); + + a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); + a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); + a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); + a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); + + array<__m256i, 4> ab; + ab[0] = _mm256_maddubs_epi16(a[0], w); + ab[1] = _mm256_maddubs_epi16(a[1], w); + ab[2] = _mm256_maddubs_epi16(a[2], w); + ab[3] = _mm256_maddubs_epi16(a[3], w); + + __m256i one = _mm256_set1_epi16(1); + ab[0] = _mm256_madd_epi16(ab[0], one); + ab[1] = _mm256_madd_epi16(ab[1], one); + ab[2] = _mm256_madd_epi16(ab[2], one); + ab[3] = _mm256_madd_epi16(ab[3], one); + + array<__m256i, 4> t; + t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); + t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); + t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); + t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); + + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 0 * 8])), + t[0])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 1 * 8])), + t[1])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 2 * 8])), + t[2])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast( + &C_buffer[j * 32 + 3 * 8])), + t[3])); + } + } // for each column of B + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + spmdm_compute_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Transpose N x 32 C_buffer to fill 32 x N submatrix of C + transpose_simd( + N, + std::min(32, i_end - i1), + reinterpret_cast(C_buffer), + 32, + reinterpret_cast(C + (i1 - block.row_start) * ldc), + ldc); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + spmdm_transpose_Nx32_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = + std::chrono::duration_cast(t_end - t_very_start) + .count(); + spmdm_run_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h new file mode 100644 index 0000000..30160d1 --- /dev/null +++ b/src/GenerateKernel.h @@ -0,0 +1,154 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * @brief AVX2/AVX512 JIT assembly code generator. + * @tparam TA Type of matrix A. 
+ * @tparam TB Type of matrix B. + * @tparam TC Type of matrix C. + * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or + * 32-bit (std::int32_t) accumulation. + */ +template +class CodeGenBase { + public: + using jit_micro_kernel_fp = void (*)( + TA* bufferA, + TB* bufferB, + TB* b_pf, + TC* bufferC, + int kc, + int ldc); + + /** + * @brief Constructor for initializing AVX2/AVX512 registers. + */ + CodeGenBase() + : CRegs_avx2_{x86::ymm0, + x86::ymm1, + x86::ymm2, + x86::ymm3, + x86::ymm4, + x86::ymm5, + x86::ymm6, + x86::ymm7, + x86::ymm8, + x86::ymm9, + x86::ymm10, + x86::ymm11}, + CRegs_avx512_{ + x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4, + x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9, + x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14, + x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, + x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24, + x86::zmm25, x86::zmm26, x86::zmm27, + } { + // vector width in bits + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + vectorWidth_ = 512; + } else if (cpuinfo_has_x86_avx2()) { + vectorWidth_ = 256; + } else { + // TODO: Have default path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + // vector width in elements + VLEN_ = vectorWidth_ / 8 * sizeof(TA); + } + + /** + * @brief Get or Create the instructions for macro-kernel. + * + * If the problem size (mc, nc) and accumulation flag (accum) can be found in + * the code cache (a hash map), then get the macro-kernel instructions + * directly from it. Otherwise, create the instructions for macro-kernel, and + * store that into the code cache. + */ + template + jit_micro_kernel_fp + getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc, int32_t ldc); + + /** + * @brief Generate instructions for initializing the C registers to 0. + */ + template + void initCRegs( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign = 4); + + /** + * @brief Generate instructions for computing block in the rank-k update. + */ + template + void genComputeBlock( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign = 4); + + /** + * @brief Generate instructions for storing the C registers back to the + * memory. + */ + template + void storeCRegs( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign = 4); + + private: + asmjit::X86Ymm + CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. + asmjit::X86Zmm + CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. + int vectorWidth_; ///< Vector width in bits. + int VLEN_; ///< Vector width in elements. + static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. + static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. + static thread_local std::map, jit_micro_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. 
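+  // Note: the cache key is only (accum, mc, nc); kc and ldc are ordinary
+  // runtime arguments of the generated jit_micro_kernel_fp. Lookup is
+  // effectively (illustrative):
+  //   auto it = codeCache_.find(std::make_tuple(accum, mc, nc));
+  //   if (it != codeCache_.end()) return it->second;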
+}; + +template +thread_local asmjit::JitRuntime CodeGenBase::rt_; + +template +thread_local asmjit::CodeHolder CodeGenBase::code_; + +template +thread_local std::map< + std::tuple, + typename CodeGenBase::jit_micro_kernel_fp> + CodeGenBase::codeCache_; + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc new file mode 100644 index 0000000..2ffe3ab --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -0,0 +1,292 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX2 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Generate AVX2 instructions for computing block in the rank-k update of 16-bit + * Accmulation kernel. + */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Ymm AReg = x86::ymm12; + + asmjit::X86Ymm tmpReg = x86::ymm14; + + for (int i = 0; i < rowRegs; ++i) { + // broadcast A + a->vpbroadcastw( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + for (int j = 0; j < colRegs; ++j) { + a->vpmaddubsw( + tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vpaddsw( + CRegs_avx2_[i * leadingDimCRegAssign + j], + tmpReg, + CRegs_avx2_[i * leadingDimCRegAssign + j]); + // Prefetching is hurting performance in some cases + // because prefetch instructions itself consumes a slot + // in pipeline issue thus slowing down the kernel. + // if((i == rowRegs - 1) && j % 2 == 0){ + // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); + //} + } + } +} + +/** + * Generate AVX2 instructions for storing the C registers back to the memory in + * 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + asmjit::X86Xmm extractDest128 = x86::xmm15; + asmjit::X86Ymm extractDest256 = x86::ymm15; + + for (int i = 0; i < rowRegs; ++i) { + a->imul(C_Offset, ldcReg, i * sizeof(int32_t)); + for (int j = 0; j < colRegs; ++j) { + for (int idx = 0; idx < 2; ++idx) { + a->vextracti128( + extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx); + a->vpmovsxwd(extractDest256, extractDest128); + asmjit::X86Mem destAddr = x86::dword_ptr( + a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); + if (accum) { + a->vpaddd(extractDest256, extractDest256, destAddr); + } + a->vmovups(destAddr, extractDest256); + } + } + } +} + +/** + * Get or Create the AVX2 instructions for 16-bit Accumulation macro-kernel. 
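+ * The emitted code walks mc in MR-row register blocks plus a remainder
+ * block, advances packed A by ROW_INTERLEAVE bytes per k step, and
+ * accumulates with saturating 16-bit adds (vpbroadcastw / vpmaddubsw /
+ * vpaddsw); storeCRegs then sign-extends the results to int32 on the way
+ * to memory.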
+ * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + // constexpr int nRegBlockSize = + // PackingTraits::NR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); + a->add(CBase, C_Offset); + // reset B + 
a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc new file mode 100644 index 0000000..e613cf1 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16_avx512.cc @@ -0,0 +1,295 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm29; + + asmjit::X86Zmm tmpReg = x86::zmm30; + + for (int i = 0; i < rowRegs; ++i) { + // broadcast A + a->vpbroadcastw( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + for (int j = 0; j < colRegs; ++j) { + a->vpmaddubsw( + tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vpaddsw( + CRegs_avx512_[i * leadingDimCRegAssign + j], + tmpReg, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + // Prefetching is hurting performance in some cases + // because prefetch instructions itself consumes a slot + // in pipeline issue thus slowing down the kernel. 
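+      // The disabled variant kept below would, on the last row of the
+      // register block, issue one prefetcht0 from the B prefetch pointer
+      // for every other column tile: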
+ // if((i == rowRegs - 1) && j % 2 == 0){ + // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); + //} + } + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + asmjit::X86Ymm extractDest256 = x86::ymm31; + asmjit::X86Zmm extractDest512 = x86::zmm31; + + for (int i = 0; i < rowRegs; ++i) { + a->imul(C_Offset, ldcReg, i * sizeof(int32_t)); + for (int j = 0; j < colRegs; ++j) { + for (int idx = 0; idx < 2; ++idx) { + a->vextracti32x8( + extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); + a->vpmovsxwd(extractDest512, extractDest256); + asmjit::X86Mem destAddr = x86::dword_ptr( + a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); + if (accum) { + a->vpaddd(extractDest512, extractDest512, destAddr); + } + a->vmovups(destAddr, extractDest512); + } + } + } +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + // constexpr int nRegBlockSize = + // PackingTraits::NR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + + int colRegs = nc * 
row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc new file mode 100644 index 0000000..dc8c6d3 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -0,0 +1,310 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX2 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx2_[i * leadingDimCReg + j], + CRegs_avx2_[i * leadingDimCReg + j], + CRegs_avx2_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX2 instructions for computing block in the rank-k update of 32-bit + * Accmulation kernel. 
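+ * Each step broadcasts four interleaved uint8 values of A (vpbroadcastd),
+ * multiplies them against a column tile of B with vpmaddubsw, widens the
+ * 16-bit pairs to 32-bit via vpmaddwd against a register of ones, and adds
+ * into the int32 accumulators with vpaddd; one prefetcht0 of the next B
+ * tile is issued per column tile.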
+ */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Ymm AReg = x86::ymm12; + + // used for matrix B + asmjit::X86Ymm BReg = x86::ymm13; + + // Contains 16-bit 1s + asmjit::X86Ymm oneReg = x86::ymm15; + + // temporary register + asmjit::X86Ymm res1 = x86::ymm14; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd( + CRegs_avx2_[i * leadingDimCRegAssign + j], + res1, + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX2 instructions for storing the C registers back to the memory in + * 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + // temp register + asmjit::X86Ymm tmpReg = x86::ymm14; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)), + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Get or Create the AVX2 instructions for 32-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); + + asmjit::X86Ymm oneReg = x86::ymm15; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); + a->mov(C_Offset, 0); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + // a->add(B_pf, 32*sizeof(float)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs( + a, 
rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc new file mode 100644 index 0000000..5cd5684 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32_avx512.cc @@ -0,0 +1,312 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accmulation kernel. 
+ */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm31; + + // used for matrix B + asmjit::X86Zmm BReg = x86::zmm30; + + // Contains 16-bit 1s + asmjit::X86Zmm oneReg = x86::zmm29; + + // temporary register + asmjit::X86Zmm res1 = x86::zmm28; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + res1, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + // temp register + asmjit::X86Zmm tmpReg = x86::zmm28; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); + + asmjit::X86Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); + a->mov(C_Offset, 0); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + // a->add(B_pf, 32*sizeof(float)); + + a->cmp(kIdx, kSize); + 
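+    // loop back to Loopk while kIdx < kSize; the K dimension is consumed
+    // row_interleave elements per iteration and kc is asserted above to be a
+    // multiple of row_interleave, so the loop is expected to finish exactly
+    // when kIdx reaches kSize.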
a->jl(Loopk); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc new file mode 100644 index 0000000..543d99b --- /dev/null +++ b/src/PackAMatrix.cc @@ -0,0 +1,165 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackAMatrix::PackAMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const T* smat, + int32_t ld, + inpType* pmat, + int32_t groups, + accT zero_pt) + : PackMatrix, T, accT>(nRow, nCol, pmat, zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + if (!pmat) { + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } +} + +template +void PackAMatrix::pack(const block_type_t& block) { + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + + BaseType::packedBlock(block_p); + bool tr = (trans_ == matrix_op_t::Transpose); + T* out = BaseType::getBuf(); + if (tr) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = smat_[i + ld_ * j]; + out[addr(i, j) - addr(block.row_start, block.col_start)] = val; + } + // zero fill + // Please note that we zero fill, not zero_pt fill, because for + // requantization original, i.e., not padded, dimensions are used. If we + // were to use padded dimensions for requantization, we would zero_pt + // fill. + // For example, consider the following dot product: + // A = .3(5-15), .3(20-15) //.3 is scale and 15 is zero_pt + // B = .4(1+10), .4(4+10) // .4 is scale and -10 is zero_pt + // + // numElements(A) = 2 and numElements(B) = 2 + // + // Dot product is (real): -3*4.4+1.5*5.6 = -4.8 + // Dot product is (quantized): 5*1+20*4 = 85 + // + // requantization: .3*.4(85 - (5+20)*(-10) - (1+4)*(15) + + // numElements(A)*(15)(-10)) = -4.8 + // + // In the above adding one more element zero in the quantized domain, + // i.e., the quantized vectors become: + // A_q = 5, 20, 0 + // B_q = 1, 4, 0 + // + // and requantization with numElements(A) = 2 will produce the same + // answer (-4.8). + // + // Also in the above adding one more element zero_pt in the quantized + // domain, i.e., the quantized vectors become: + // A_q = 5, 20, 15 + // B_q = 1, 4, -10 + // + // and requantization with numElements(A) = 3 will produce the same + // answer (-4.8). 
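+      // Stated generally, the requantization referred to above computes
+      //   real ~ scale_A * scale_B * (dot_q
+      //                               - B_zero_pt * rowSum(A_q)
+      //                               - A_zero_pt * colSum(B_q)
+      //                               + numElements * A_zero_pt * B_zero_pt)
+      // with the sums and numElements taken over the original (unpadded)
+      // dimensions, so appending literal zeros to A_q and B_q changes neither
+      // dot_q nor the sums and leaves the result intact.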
+ for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; + ++j) { + out[addr(i, j) - addr(block.row_start, block.col_start)] = 0; + } + } + } else { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int buf_idx = i - block.row_start; + memcpy( + out + buf_idx * BaseType::blockColSize(), + smat_ + i * ld_ + block.col_start, + block.col_size * sizeof(T)); + // zero fill + for (int j = block.col_size; j < block_p.col_size; ++j) { + out[buf_idx * BaseType::blockColSize() + j] = 0; + } + } + } +} + +template +int32_t PackAMatrix::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template +void PackAMatrix::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template class PackAMatrix; +template class PackAMatrix; +} // namespace fbgemm2 diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc new file mode 100644 index 0000000..7012289 --- /dev/null +++ b/src/PackAWithIm2Col.cc @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackAWithIm2Col::PackAWithIm2Col( + const conv_param_t& conv_p, + const T* sdata, + inpType* pmat, + int32_t zero_pt, + int32_t* row_offset) + : PackMatrix, T, accT>( + conv_p.MB * conv_p.OH * conv_p.OW, + conv_p.KH * conv_p.KW * conv_p.IC, + pmat, + zero_pt), + conv_p_(conv_p), + sdata_(sdata) { + assert(conv_p.G == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = static_cast( + aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T))); + } + if (row_offset) { + row_offset_ = row_offset; + } else { + rowOffsetAllocatedHere = true; + row_offset_ = static_cast( + aligned_alloc(64, BaseType::brow_ * sizeof(int32_t))); + } +} + +template +void PackAWithIm2Col::pack(const block_type_t& block) { + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + + BaseType::packedBlock(block_p); + T* out = BaseType::getBuf(); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int n = i / (conv_p_.OH * conv_p_.OW); + int hw = i % (conv_p_.OH * conv_p_.OW); + int w = hw % conv_p_.OW; + int h = hw / conv_p_.OW; + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + int c = j % conv_p_.IC; + int rs = j / conv_p_.IC; + int s = rs % conv_p_.KW; + int r = rs / conv_p_.KW; + + int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s; + int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r; + // Please note that padding for convolution should be filled with zero_pt + if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = BaseType::zeroPoint(); + } else { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = sdata_ + [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c]; + } + } + // zero fill + // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill. 
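+    // As in PackAMatrix, each packed row is padded with zeros out to the next
+    // multiple of row_interleave_B_ so the generated kernel always reads
+    // complete row_interleave-sized groups along the K dimension.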
+ for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; + ++j) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = 0; + } + } +} + +template +void PackAWithIm2Col::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[ r * BaseType::blockColSize() + c ]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +int PackAWithIm2Col::rowOffsetBufferSize() { + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + return PackingTraits::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits::MCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithIm2Col; +template class PackAWithIm2Col; +} // namespace fbgemm2 diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc new file mode 100644 index 0000000..30d94f8 --- /dev/null +++ b/src/PackBMatrix.cc @@ -0,0 +1,144 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackBMatrix::PackBMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const T* smat, + int32_t ld, + inpType* pmat, + int32_t groups, + accT zero_pt) + : PackMatrix, T, accT>(nRow, nCol, pmat, zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits::KCB; + BaseType::bcol_ = PackingTraits::NCB; + row_interleave_ = + PackingTraits::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits::KCB; + BaseType::bcol_ = PackingTraits::NCB; + row_interleave_ = PackingTraits::ROW_INTERLEAVE; + } else { + // Error + assert(0 && "unknown architecure"); + } + block_type_t block{0, BaseType::numRows(), 0, BaseType::numCols()}; + BaseType::packedBlock(block); + if (!pmat) { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = (T*)aligned_alloc( + 64, + BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() * + BaseType::bcol_ * sizeof(T)); + } + pack(block); +} + +template +void PackBMatrix::pack(const block_type_t& block) { + assert((BaseType::blockRowSize() % row_interleave_) == 0); + + BaseType::packedBlock(block); + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = tr ? smat_[i + ld_ * j] : smat_[i * ld_ + j]; + out[addr(i, j) - addr(block.row_start, block.col_start)] = + tconv(val, out[addr(i, j)]); + } + } + // fill the remaining with zero. 
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. + for (int i = block.row_start + block.row_size; + i < (block.row_start + block.row_size + row_interleave_ - 1) / + row_interleave_ * row_interleave_; + ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; j++) { + out[addr(i, j) - addr(block.row_start, block.col_start)] = + tconv(0, out[addr(i, j)]); + } + } +} + +template +int32_t PackBMatrix::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = (r % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + (c % BaseType::blockColSize()) * row_interleave_ + r % row_interleave_; + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template +void PackBMatrix::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + std::cout << "block size:" + << "[" << BaseType::blockRowSize() << ", " + << BaseType::blockColSize() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; + } + } +} + +template class PackBMatrix; +template class PackBMatrix; +} // namespace fbgemm2 diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc new file mode 100644 index 0000000..85000ac --- /dev/null +++ b/src/PackMatrix.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include "fbgemm/ConvUtils.h" +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackMatrix::PackMatrix( + int32_t rows, + int32_t cols, + inpType* buf, + int32_t zero_pt) + : buf_(buf), nrows_(rows), ncols_(cols), zero_pt_(zero_pt) { + bufAllocatedHere_ = false; + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template +int PackMatrix::packedBufferSize(int rows, int cols) { + if (cpuinfo_has_x86_avx512f()) { + if (isA) { + return PackingTraits::MCB * + PackingTraits::KCB; + } else { + int rowBlock = PackingTraits::KCB; + int colBlock = PackingTraits::NCB; + return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * + (((cols + colBlock - 1) / colBlock) * colBlock); + } + } else if (cpuinfo_has_x86_avx2()) { + if (isA) { + return PackingTraits::MCB * + PackingTraits::KCB; + } else { + int rowBlock = PackingTraits::KCB; + int colBlock = PackingTraits::NCB; + return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * + (((cols + colBlock - 1) / colBlock) * colBlock); + } + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + return -1; +} + +// int32 accumulation +template class PackMatrix, uint8_t, int32_t>; + +template class PackMatrix< + PackAWithRowOffset, + uint8_t, + int32_t>; + +template class PackMatrix, uint8_t, int32_t>; + +template class PackMatrix< + PackAWithQuantRowOffset, + uint8_t, + int32_t>; + +template class PackMatrix, int8_t, int32_t>; + +// int16 accumulation +template class PackMatrix, uint8_t, int16_t>; + +template class PackMatrix< + PackAWithRowOffset, + uint8_t, + int16_t>; + +template class PackMatrix, uint8_t, int16_t>; + +template class PackMatrix, int8_t, int16_t>; +} // namespace fbgemm2 diff --git a/src/PackWithQuantRowOffset.cc b/src/PackWithQuantRowOffset.cc new file mode 100644 index 0000000..74eaade --- /dev/null +++ b/src/PackWithQuantRowOffset.cc @@ -0,0 +1,230 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackAWithQuantRowOffset::PackAWithQuantRowOffset( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const float* smat, + int32_t ld, + inpType* pmat, + float scale, + int32_t zero_pt, + int32_t groups, + int32_t* row_offset) + : PackMatrix, T, accT>( + nRow, + nCol, + pmat, + zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + scale_(scale), + G_(groups), + row_offset_(row_offset) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + rowOffsetAllocatedHere = false; + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } + if (!row_offset_) { + rowOffsetAllocatedHere = true; + row_offset_ = reinterpret_cast( + aligned_alloc(64, BaseType::brow_ * sizeof(accT))); + } +} + +template +void PackAWithQuantRowOffset::pack(const block_type_t& block) { + assert(block.row_start % BaseType::blockRowSize() == 0); + assert(block.col_start % BaseType::blockColSize() == 0); + assert(block.row_size <= BaseType::blockRowSize()); + assert(block.col_size <= BaseType::blockColSize()); + + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + assert(block_p.col_size <= BaseType::blockColSize()); + BaseType::packedBlock(block_p); + + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + // accumulate into row offset? + bool row_offset_acc = (block.col_start != 0); + int32_t* row_offset_buf = getRowOffsetBuffer(); + + float smat_transposed[block.row_size * block.col_size]; + if (tr) { + transpose_simd( + block.col_size, + block.row_size, + smat_ + block.col_start * ld_ + block.row_start, + ld_, + smat_transposed, + block.col_size); + } + const float* smat_temp = + tr ? smat_transposed : smat_ + block.row_start * ld_ + block.col_start; + int32_t ld_temp = tr ? block.col_size : ld_; + +#if defined(__AVX2__) && defined(__FMA__) + constexpr int VLEN = 8; + __m256 inverse_scale_v = _mm256_set1_ps(1.0f / scale_); + __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00); + __m256i permute_mask_v = _mm256_set_epi32( + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); +#endif + + for (int i = 0; i < block.row_size; ++i) { + int32_t row_sum = row_offset_acc ? 
row_offset_buf[i] : 0; + int j = 0; +#if defined(__AVX2__) && defined(__FMA__) + static_assert( + std::is_same::value, + "PackAWithQuantRowOffset::pack only works for T == uint8_t"); + for (; j < block.col_size / VLEN * VLEN; j += VLEN) { + __m256 val_v = _mm256_loadu_ps(smat_temp + i * ld_temp + j); + __m256 transformed_v = _mm256_fmadd_ps( + val_v, inverse_scale_v, _mm256_set1_ps(BaseType::zeroPoint())); + __m256 clipped_v = _mm256_max_ps( + _mm256_set1_ps(std::numeric_limits::min()), + _mm256_min_ps( + transformed_v, + _mm256_set1_ps(std::numeric_limits::max()))); + __m256i res_v = _mm256_cvtps_epi32(clipped_v); + + // An instruction sequence to save 8 32-bit integers as 8 8-bit integers + res_v = _mm256_shuffle_epi8(res_v, shuffle_mask_v); + res_v = _mm256_permutevar8x32_epi32(res_v, permute_mask_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(out + i * BaseType::blockColSize() + j), + _mm256_castsi256_si128(res_v)); + + for (int j2 = j; j2 < j + VLEN; ++j2) { + row_sum += out[i * BaseType::blockColSize() + j2]; + } + } +#endif + for (; j < block.col_size; ++j) { + float val = smat_temp[i * ld_temp + j]; + float transformed = val / scale_ + BaseType::zeroPoint(); + float clipped = std::min( + std::max(transformed, std::numeric_limits::min()), + std::numeric_limits::max()); + T res = round(clipped); + row_sum += res; + out[i * BaseType::blockColSize() + j] = res; + } + // zero fill + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. + for (; j < block_p.col_size; ++j) { + out[i * BaseType::blockColSize() + j] = 0; + } + row_offset_buf[i] = row_sum; + } +} + +template +int32_t PackAWithQuantRowOffset::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template +void PackAWithQuantRowOffset::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +int PackAWithQuantRowOffset::rowOffsetBufferSize() { + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + return PackingTraits::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits::MCB; + } else { + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithQuantRowOffset; + +} // namespace fbgemm2 diff --git a/src/PackWithRowOffset.cc b/src/PackWithRowOffset.cc new file mode 100644 index 0000000..8722723 --- /dev/null +++ 
b/src/PackWithRowOffset.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template +PackAWithRowOffset::PackAWithRowOffset( + matrix_op_t trans, + uint32_t nRow, + uint32_t nCol, + const T* smat, + uint32_t ld, + inpType* pmat, + uint32_t groups, + int32_t zero_pt, + int32_t* row_offset) + : PackMatrix, T, accT>( + nRow, + nCol, + pmat, + zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups), + row_offset_(row_offset) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + rowOffsetAllocatedHere = false; + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + //TODO: Have default slower path + assert(0 && "unknown architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } + if (!row_offset_) { + rowOffsetAllocatedHere = true; + row_offset_ = static_cast(aligned_alloc(64, + BaseType::brow_ * sizeof(int32_t))); + } +} + +template +void PackAWithRowOffset::pack(const block_type_t& block) { + assert(block.row_start % BaseType::blockRowSize() == 0); + assert(block.col_start % BaseType::blockColSize() == 0); + assert(block.row_size <= BaseType::blockRowSize()); + assert(block.col_size <= BaseType::blockColSize()); + + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + assert(block_p.col_size <= BaseType::blockColSize()); + BaseType::packedBlock(block_p); + + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + // accumulate into row offset? + bool row_offset_acc = (block.col_start != 0); + int32_t* row_offset_buf = getRowOffsetBuffer(); + if (tr) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int32_t row_sum = row_offset_acc ? + row_offset_buf[i - block.row_start] : 0; + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = smat_[i + ld_ * j]; + row_sum += val; + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = val; + } + row_offset_buf[i - block.row_start] = row_sum; + // zero fill + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. + for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; ++j) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = 0; + } + } + } else { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int buf_idx = i - block.row_start; + memcpy( + out + buf_idx * BaseType::blockColSize(), + smat_ + i * ld_ + block.col_start, + block.col_size * sizeof(T)); + // zero fill + for (int j = block.col_size; j < block_p.col_size; ++j) { + out[buf_idx * BaseType::blockColSize() + j] = 0; + } + int32_t row_sum = row_offset_acc ? 
+ row_offset_buf[i - block.row_start] : 0; + __m256i sum_v = _mm256_setzero_si256(); + __m256i one_epi16_v = _mm256_set1_epi16(1); + __m256i one_epi8_v = _mm256_set1_epi8(1); + for (int j = block.col_start; + j < block.col_start + block.col_size / 32 * 32; + j += 32) { + __m256i src_v = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(smat_ + i * ld_ + j)); + sum_v = _mm256_add_epi32( + sum_v, + _mm256_madd_epi16( + _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v)); + } + for (int j = block.col_start + block.col_size / 32 * 32; + j < block.col_start + block.col_size; + ++j) { + row_sum += smat_[i * ld_ + j]; + } + alignas(64) std::array temp; + _mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v); + for (int k = 0; k < 8; ++k) { + row_sum += temp[k]; + } + row_offset_buf[i - block.row_start] = row_sum; + } + } +} + +template +int32_t PackAWithRowOffset::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template +void PackAWithRowOffset::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +int PackAWithRowOffset::rowOffsetBufferSize() { + if(cpuinfo_initialize()){ + if (cpuinfo_has_x86_avx512f()) { + return PackingTraits::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits::MCB; + } else { + //TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithRowOffset; +template class PackAWithRowOffset; + +} // namespace fbgemm2 diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc new file mode 100644 index 0000000..9aedc88 --- /dev/null +++ b/src/RefImplementations.cc @@ -0,0 +1,608 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "RefImplementations.h" + +#include +#include + +using namespace std; + +namespace fbgemm2 { + +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const int32_t* inp, + uint8_t* out, + int32_t C_multiplier, + int32_t C_right_shift, + int32_t C_zero_point, + int32_t A_zero_point, + int32_t B_zero_point, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu) { + int64_t nudge = 1ll << std::max(0, C_right_shift - 1); + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + int32_t raw = inp[i * ld + j]; + raw -= A_zero_point * col_offsets[j]; + raw -= B_zero_point * row_offsets[i]; + if (bias) { + raw += bias[j]; + } + + int64_t ab_64 = + static_cast(raw) * static_cast(C_multiplier); + int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point; + + out[i * ld + j] = std::max( + fuse_relu ? static_cast(C_zero_point) : 0l, + std::min(255l, rounded)); + } + } +} + +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const int32_t* inp, + uint8_t* out, + float C_multiplier, + int32_t C_zero_point, + int32_t A_zero_point, + int32_t B_zero_point, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + int32_t raw = inp[i * ld + j]; + raw -= A_zero_point * col_offsets[j]; + raw -= B_zero_point * row_offsets[i]; + if (bias) { + raw += bias[j]; + } + + float result = raw * C_multiplier; + long rounded = lrintf(result) + C_zero_point; + out[i * ld + j] = std::max( + fuse_relu ? static_cast(C_zero_point) : 0l, + std::min(255l, rounded)); + } + } +} + +void matmul_u8i8acc32_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const uint8_t* Aint8, + const int8_t* Bint8, + int32_t* Cint32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + int32_t sum = 0; + for (int k = 0; k < K; ++k) { + sum += static_cast(Aint8[i * lda + k]) * + static_cast(Bint8[k * ldb + j]); + } + Cint32[i * ldc + j] = sum; + } + } +} + +void matmul_u8i8acc16_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int brow, + const uint8_t* Aint8, + const int8_t* Bint8, + int32_t* Cint32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + int32_t sum = 0, sum_32bit = 0; + for (int k = 0; k < K; k += 2) { + int a0 = Aint8[i * lda + k]; + int b0 = Bint8[k * ldb + j]; + int a1 = 0, b1 = 0; + if (k + 1 < K) { + a1 = Aint8[i * lda + k + 1]; + b1 = Bint8[(k + 1) * ldb + j]; + } + sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1)); + if ((k % brow) == (brow - 2)) { + sum_32bit += sum; + sum = 0; + } + } + Cint32[i * ldc + j] = sum_32bit + sum; + } + } +} + +void matmul_fp_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const float* Afp32, + const float* Bfp32, + float* Cfp32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + float sum = 0; + for (int k = 0; k < K; ++k) { + sum += Afp32[i * lda + k] * Bfp32[k * ldb + j]; + } + Cfp32[i * ldc + j] = sum; + } + } +} + +void row_offsets_u8acc32_ref( + int M, + int K, + int ld, + const uint8_t* Aint8, + int32_t* row_offsets) { + // row offset + for (int i = 0; i < M; ++i) { + int32_t sum = 0; + for (int k = 0; k < K; ++k) { + sum += static_cast(Aint8[i * ld + k]); + } + row_offsets[i] = sum; + } +} + +void col_offsets_with_zero_pt_s8acc32_ref( + int K, + int N, + int ld, + const int8_t* Bint8, + int32_t B_zero_point, + int32_t* col_offsets) { + for (int j = 0; j < N; ++j) { + int32_t sum = 0; + 
for (int k = 0; k < K; ++k) { + sum += Bint8[k * ld + j]; + } + col_offsets[j] = sum - B_zero_point * K; + } +} + +void spmdm_ref( + int M, + const uint8_t* A, + int lda, + fbgemm2::CompressedSparseColumn& B, + bool accumulation, + int32_t* C, + int ldc) { + int N = B.NumOfCols(); + if (!accumulation) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + C[i * ldc + j] = 0; + } + } + } + for (int j = 0; j < N; ++j) { + for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) { + int row = B.RowIdx()[k]; + int w = B.Values()[k]; + for (int i = 0; i < M; ++i) { + C[i * ldc + j] += A[i * lda + row] * w; + } + } + } // for each column of B +} + +int32_t clip_16bit(int32_t x) { + if (x > std::numeric_limits::max()) { + return std::min(std::numeric_limits::max(), x); + } else if (x < std::numeric_limits::min()) { + return std::max(std::numeric_limits::min(), x); + } else { + return x; + } +} + +void im2col_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + std::uint8_t* Ao) { + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int r = 0; r < conv_p.KH; ++r) { + int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; + for (int s = 0; s < conv_p.KW; ++s) { + int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; + for (int c = 0; c < conv_p.IC; ++c) { + // Ai: NHWC: NH_0W_0 x C_0 + std::uint8_t val = + h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW + ? A_zero_point + : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC + + c]; + // Ao: NHWC: NH_1W_1 x RSC_0 + Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * + conv_p.KW + + s) * + conv_p.IC + + c] = val; + } // for each c + } // for each s + } // for each r + } // for each w + } // for each h + } // for each n +} + +void conv_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* B, + std::int32_t* C) { + // filters are assumed to be in RSCK format + assert(conv_p.G == 1 && "Groups != 1 not supported yet"); + + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int sum = 0; + for (int r = 0; r < conv_p.KH; ++r) { + int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; + for (int s = 0; s < conv_p.KW; ++s) { + int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; + for (int c = 0; c < conv_p.IC; ++c) { + int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 || + w_in >= conv_p.IW + ? 
A_zero_point + : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * + conv_p.IC + + c]; + int b = + B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k]; + sum += a * b; + } // for each c + } // for each s + } // for each r + C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum; + } // for each k + } // for each w + } // for each h + } // for each n +} + +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int8_t* B, + int32_t* C) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + int b = B[(k * R + r) * S + s]; + sum += a * b; + } + } + C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n +}; + +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + vector C_int32(N * H_OUT * W_OUT * K); + depthwise_3x3_pad_1_ref( + N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); + + vector row_offsets(N * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n + + for (int i = 0; i < N * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier, + C_zero_point, + A_zero_point, + B_zero_point, + &row_offsets[i * K + k], + col_offsets + k, + bias ? 
bias + k : nullptr); + } + } +}; + +void depthwise_3x3_per_channel_quantization_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + vector C_int32(N * H_OUT * W_OUT * K); + depthwise_3x3_pad_1_ref( + N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); + + vector row_offsets(N * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n + + for (int i = 0; i < N * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier[k], + C_zero_point, + A_zero_point, + B_zero_point[k], + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr); + } + } +}; + +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int8_t* B, + int32_t* C) { + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int k_t = 0; k_t < K_T; ++k_t) { + int t_in = -PAD_P + t * stride_t + k_t; + for (int k_h = 0; k_h < K_H; ++k_h) { + int h_in = -PAD_T + h * stride_h + k_h; + for (int k_w = 0; k_w < K_W; ++k_w) { + int w_in = -PAD_L + w * stride_w + k_w; + int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || + w_in < 0 || w_in >= W + ? 
A_zero_point + : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; + int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w]; + sum += a * b; + } + } + } + C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } // w + } // h + } // t + } // for each n +}; + +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + vector C_int32(N * T_OUT * H_OUT * W_OUT * K); + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B, + C_int32.data()); + + vector row_offsets(N * T_OUT * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int k_t = 0; k_t < K_T; ++k_t) { + int t_in = -PAD_P + t * stride_t + k_t; + for (int k_h = 0; k_h < K_H; ++k_h) { + int h_in = -PAD_T + h * stride_h + k_h; + for (int k_w = 0; k_w < K_W; ++k_w) { + int w_in = -PAD_L + w * stride_w + k_w; + int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || + w_in < 0 || w_in >= W + ? A_zero_point + : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + } + row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = + sum; + } + } // w + } // h + } // t + } // for each n + + for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier, + C_zero_point, + A_zero_point, + B_zero_point, + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr); + } + } +}; + +} // namespace fbgemm2 diff --git a/src/RefImplementations.h b/src/RefImplementations.h new file mode 100644 index 0000000..e9eaeed --- /dev/null +++ b/src/RefImplementations.h @@ -0,0 +1,268 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#include "fbgemm/ConvUtils.h" +#include "fbgemm/FbgemmI8Spmdm.h" + +namespace fbgemm2 { + +/** + * @brief Reference implementation of requantization step. + * int32 multiplier + * @params bias can be nullptr + */ +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const std::int32_t* inp, + std::uint8_t* out, + std::int32_t C_multiplier, + std::int32_t C_right_shift, + std::int32_t C_zero_point, + std::int32_t A_zero_point, + std::int32_t B_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false); + +/** + * @brief Reference implementation of requantization step. 
+ * float multiplier + * @params bias can be nullptr + */ +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const std::int32_t* inp, + std::uint8_t* out, + float C_multiplier, + std::int32_t C_zero_point, + std::int32_t A_zero_point, + std::int32_t B_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false); + +/** + * @brief Reference implementation of matrix multiply with uint8 for A, + * int8 for B, and 32-bit accumulation. + */ +void matmul_u8i8acc32_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const std::uint8_t* Aint8, + const std::int8_t* Bint8, + std::int32_t* Cint32); + +/** + * @brief Reference implementation of matrix multiply with uint 8 for A, + * int8 for B, and 16-bit accumulation. + */ +void matmul_u8i8acc16_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int brow, + const std::uint8_t* Aint8, + const std::int8_t* Bint8, + std::int32_t* Cint32); + +/** + * @brief Reference implementation of matrix multiply with fp32 (single + * precision) floating point number. + */ +void matmul_fp_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const float* Afp32, + const float* Bfp32, + float* Cfp32); + +/** + * @brief Reference implementation to compute row_offsets (sums of rows of A). + */ +void row_offsets_u8acc32_ref( + int M, + int K, + int ld, + const std::uint8_t* Aint8, + std::int32_t* row_offsets); + +/** + * @brief Reference implementation to compute adjusted col_offsets (sum of + * columns of B and adjusted with B_zero_point) + */ +void col_offsets_with_zero_pt_s8acc32_ref( + int K, + int N, + int ld, + const std::int8_t* Bint8, + std::int32_t B_zero_point, + std::int32_t* col_offsets); + +/** + * @brief Reference implementation of SPMDM (sparse matrix times dense matrix). + */ +void spmdm_ref( + int M, + const std::uint8_t* A, + int lda, + CompressedSparseColumn& B, + bool accumulation, + std::int32_t* C, + int ldc); + +/* + * @brief Trim a 32-bit integer to a 16-bit integer. + */ +int32_t clip_16bit(int32_t x); + +/* + * @brief Reference implementation of convolution operation. + * The activations A are assumed to be in NHiWiC format. + * The filters B are assumed to be in RSCK format. + * The output C is assumed to be in NHoWoC format. + */ +void conv_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of im2col operation. + * The input A is assumed to be in NHiWiC format. + * The output A is assumed to be in NHoWoRSC format. + */ +void im2col_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + std::uint8_t* Ao); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1. + */ +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1, followed by requantization. (the same scaling factors and + * zero points for each channel). 
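+ *
+ * Roughly, the requantization applies the usual asymmetric-quantization
+ * correction (an illustrative summary; see requantize_u8acc32_ref above for
+ * the exact reference, including rounding and clamping to [0, 255]):
+ *   C_u8 = C_zero_point + C_multiplier *
+ *          (C_i32 - A_zero_point * col_offset - B_zero_point * row_offset + bias)
+ * where the per-output-pixel row offsets are sums of the A values that
+ * contribute to that output, and col_offsets/bias are indexed per channel.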
+ */ +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const std::int8_t* B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1, followed by requantization. (different scaling factors + * and zero points for each channel). + */ +void depthwise_3x3_per_channel_quantization_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const std::int8_t* B, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +/* + * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 + * filter and padding size 1. + */ +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 + * filter and padding size 1, followed by requantization. + */ +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const std::int8_t* B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +} // namespace fbgemm2 diff --git a/src/Utils.cc b/src/Utils.cc new file mode 100644 index 0000000..10ab469 --- /dev/null +++ b/src/Utils.cc @@ -0,0 +1,357 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/Utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fbgemm2 { + +/** + * @brief Compare the reference and test result matrix to check the correctness. + * @param ref The buffer for the reference result matrix. + * @param test The buffer for the test result matrix. + * @param m The height of the reference and test result matrix. + * @param n The width of the reference and test result matrix. + * @param ld The leading dimension of the reference and test result matrix. + * @param max_mismatches_to_report The maximum number of tolerable mismatches to + * report. + * @param atol The tolerable error. + * @retval false If the number of mismatches for reference and test result + * matrix exceeds max_mismatches_to_report. + * @retval true If the number of mismatches for reference and test result matrix + * is tolerable. 
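+ *
+ * An illustrative call, comparing two m x n row-major buffers whose leading
+ * dimension is n and reporting at most five mismatches:
+ *   compare_buffers(C_ref.data(), C_test.data(), m, n, n, 5);
+ * A nonzero return indicates that the mismatch limit was exceeded.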
+ */ +template +int compare_buffers( + const T* ref, + const T* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol /*=1e-3*/) { + size_t mismatches = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + T reference = ref[i * ld + j], actual = test[i * ld + j]; + if (std::abs(reference - actual) > atol) { + std::cout << "\tmismatch at (" << i << ", " << j << ")" << std::endl; + if (std::is_integral::value) { + std::cout << "\t reference:" << static_cast(reference) + << " test:" << static_cast(actual) << std::endl; + } else { + std::cout << "\t reference:" << reference << " test:" << actual + << std::endl; + } + + mismatches++; + if (mismatches > max_mismatches_to_report) { + return 1; + } + } + } + } + return 0; +} + + +/** + * @brief Print the matrix. + * @param op Transpose type of the matrix. + * @param R The height of the matrix. + * @param C The width of the matrix. + * @param ld The leading dimension of the matrix. + * @param name The prefix string before printing the matrix. + */ +template +void printMatrix( + matrix_op_t op, + const T* inp, + size_t R, + size_t C, + size_t ld, + std::string name) { + // R: number of rows in op(inp) + // C: number of cols in op(inp) + // ld: leading dimension in inp + std::cout << name << ":" + << "[" << R << ", " << C << "]" << std::endl; + bool tr = (op == matrix_op_t::Transpose); + for (auto r = 0; r < R; ++r) { + for (auto c = 0; c < C; ++c) { + T res = tr ? inp[c * ld + r] : inp[r * ld + c]; + if (std::is_integral::value) { + std::cout << std::setw(5) << static_cast(res) << " "; + } else { + std::cout << std::setw(5) << res << " "; + } + } + std::cout << std::endl; + } +} + +template int compare_buffers( + const float* ref, + const float* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template int compare_buffers( + const int32_t* ref, + const int32_t* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template int compare_buffers( + const uint8_t* ref, + const uint8_t* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template void printMatrix( + matrix_op_t op, + const float* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix( + matrix_op_t op, + const int8_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix( + matrix_op_t op, + const uint8_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix( + matrix_op_t op, + const int32_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); + + +/** + * @brief Reference implementation of matrix transposition: B = A^T. + * @param M The height of the matrix. + * @param N The width of the matrix. + * @param src The memory buffer of the source matrix A. + * @param ld_src The leading dimension of the source matrix A. + * @param dst The memory buffer of the destination matrix B. + * @param ld_dst The leading dimension of the destination matrix B. 
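+ *
+ * Note that dst receives the N x M transpose, so ld_dst must be at least M.
+ * An illustrative call, transposing a 2 x 3 row-major matrix (ld_src = 3)
+ * into a 3 x 2 one (ld_dst = 2):
+ *   transpose_ref(2, 3, src, 3, dst, 2);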
+ */ +inline void transpose_ref( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + dst[i + j * ld_dst] = src[i * ld_src + j]; + } + } +} + +inline void +transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 + // b : b0 b1 b2 b3 + // c : c0 c1 c2 c3 + // d : d0 d1 d2 d3 + __m128 a = _mm_loadu_ps(&src[0 * ld_src]); + __m128 b = _mm_loadu_ps(&src[1 * ld_src]); + __m128 c = _mm_loadu_ps(&src[2 * ld_src]); + __m128 d = _mm_loadu_ps(&src[3 * ld_src]); + + // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE + // a : a0 b0 c0 d0 + // b : a1 b1 c1 d1 + // c : a2 b2 c2 d2 + // d : a3 b3 c3 d3 + _MM_TRANSPOSE4_PS(a, b, c, d); + + // store from registers to dst + _mm_storeu_ps(&dst[0 * ld_dst], a); + _mm_storeu_ps(&dst[1 * ld_dst], b); + _mm_storeu_ps(&dst[2 * ld_dst], c); + _mm_storeu_ps(&dst[3 * ld_dst], d); +} +inline void transpose_4x4( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 4 <= M; ib += 4) { + for (jb = 0; jb + 4 <= N; jb += 4) { + transpose_kernel_4x4_sse( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +inline void transpose_kernel_8x8_avx2( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 a4 a5 a6 a7 + // b : b0 b1 b2 b3 b4 b5 b6 b7 + // c : c0 c1 c2 c3 c4 c5 c6 c7 + // d : d0 d1 d2 d3 d4 d5 d6 d7 + // e : e0 e1 e2 e3 e4 e5 e6 e7 + // f : f0 f1 f2 f3 f4 f5 f6 f7 + // g : g0 g1 g2 g3 g4 g5 g6 g7 + // h : h0 h1 h2 h3 h4 h5 h6 h7 + __m256 a = _mm256_loadu_ps(&src[0 * ld_src]); + __m256 b = _mm256_loadu_ps(&src[1 * ld_src]); + __m256 c = _mm256_loadu_ps(&src[2 * ld_src]); + __m256 d = _mm256_loadu_ps(&src[3 * ld_src]); + __m256 e = _mm256_loadu_ps(&src[4 * ld_src]); + __m256 f = _mm256_loadu_ps(&src[5 * ld_src]); + __m256 g = _mm256_loadu_ps(&src[6 * ld_src]); + __m256 h = _mm256_loadu_ps(&src[7 * ld_src]); + + __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367; + __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37; + // unpacking and interleaving 32-bit elements + // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5 + // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7 + // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5 + // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7 + // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5 + // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7 + // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5 + // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7 + ab0145 = _mm256_unpacklo_ps(a, b); + ab2367 = _mm256_unpackhi_ps(a, b); + cd0145 = _mm256_unpacklo_ps(c, d); + cd2367 = _mm256_unpackhi_ps(c, d); + ef0145 = _mm256_unpacklo_ps(e, f); + ef2367 = _mm256_unpackhi_ps(e, f); + gh0145 = _mm256_unpacklo_ps(g, h); + gh2367 = _mm256_unpackhi_ps(g, h); + + // shuffling the 32-bit elements + // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4 + // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5 + // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4 + // efgh15 : e1 f1 g1 h1 e5 b5 c5 d5 + // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6 + // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7 + // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6 + // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7 + abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44); + abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee); + efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44); + 
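+  // (Shuffle immediates: within each 128-bit lane, 0x44 picks elements {0,1}
+  // of both sources and 0xee picks elements {2,3}, which is what merges the
+  // unpacked ab/cd and ef/gh pairs into abcd/efgh quadruples.)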
efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee); + abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44); + abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee); + efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44); + efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee); + + // shuffling 128-bit elements + // a : a0 b0 c0 d0 e0 f0 g0 h0 + // b : a1 b1 c1 d1 e1 f1 g1 h1 + // c : a2 b2 c2 d2 e2 f2 g2 h2 + // d : a3 b3 c3 d3 e3 f3 g3 h3 + // e : a4 b4 c4 d4 e4 f4 g4 h4 + // f : a5 b5 c5 d5 e5 f5 g5 h5 + // g : a6 b6 c6 d6 e6 f6 g6 h6 + // h : a7 b7 c7 d7 e7 f7 g7 h7 + a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02); + b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02); + c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02); + d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02); + e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13); + f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13); + g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13); + h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13); + + // store from registers to dst + _mm256_storeu_ps(&dst[0 * ld_dst], a); + _mm256_storeu_ps(&dst[1 * ld_dst], b); + _mm256_storeu_ps(&dst[2 * ld_dst], c); + _mm256_storeu_ps(&dst[3 * ld_dst], d); + _mm256_storeu_ps(&dst[4 * ld_dst], e); + _mm256_storeu_ps(&dst[5 * ld_dst], f); + _mm256_storeu_ps(&dst[6 * ld_dst], g); + _mm256_storeu_ps(&dst[7 * ld_dst], h); +} + +void transpose_8x8( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 8 <= M; ib += 8) { + for (jb = 0; jb + 8 <= N; jb += 8) { + transpose_kernel_8x8_avx2( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +void transpose_simd( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // Run time CPU detection + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + transpose_16x16(M, N, src, ld_src, dst, ld_dst); + } else if (cpuinfo_has_x86_avx2()) { + transpose_8x8(M, N, src, ld_src, dst, ld_dst); + } else { + transpose_ref(M, N, src, ld_src, dst, ld_dst); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +} // namespace fbgemm2 diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc new file mode 100644 index 0000000..b6bf413 --- /dev/null +++ b/src/Utils_avx512.cc @@ -0,0 +1,243 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "fbgemm/Utils.h" + +#include + +namespace fbgemm2 { + +inline void transpose_kernel_16x16_avx512( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 + __m512 a = _mm512_loadu_ps(&src[0 * ld_src]); + __m512 b = _mm512_loadu_ps(&src[1 * ld_src]); + __m512 c = _mm512_loadu_ps(&src[2 * ld_src]); + __m512 d = _mm512_loadu_ps(&src[3 * ld_src]); + __m512 e = _mm512_loadu_ps(&src[4 * ld_src]); + __m512 f = _mm512_loadu_ps(&src[5 * ld_src]); + __m512 g = _mm512_loadu_ps(&src[6 * ld_src]); + __m512 h = _mm512_loadu_ps(&src[7 * ld_src]); + __m512 i = _mm512_loadu_ps(&src[8 * ld_src]); + __m512 j = _mm512_loadu_ps(&src[9 * ld_src]); + __m512 k = _mm512_loadu_ps(&src[10 * ld_src]); + __m512 l = _mm512_loadu_ps(&src[11 * ld_src]); + __m512 m = _mm512_loadu_ps(&src[12 * ld_src]); + __m512 n = _mm512_loadu_ps(&src[13 * ld_src]); + __m512 o = _mm512_loadu_ps(&src[14 * ld_src]); + __m512 p = _mm512_loadu_ps(&src[15 * ld_src]); + + __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13 + // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + // i0 ... + // i2 ... + // k0 ... + // k2 ... + // m0 ... + // m2 ... + // o0 ... + // o1 ... + ta = _mm512_unpacklo_ps(a, b); + tb = _mm512_unpackhi_ps(a, b); + tc = _mm512_unpacklo_ps(c, d); + td = _mm512_unpackhi_ps(c, d); + te = _mm512_unpacklo_ps(e, f); + tf = _mm512_unpackhi_ps(e, f); + tg = _mm512_unpacklo_ps(g, h); + th = _mm512_unpackhi_ps(g, h); + ti = _mm512_unpacklo_ps(i, j); + tj = _mm512_unpackhi_ps(i, j); + tk = _mm512_unpacklo_ps(k, l); + tl = _mm512_unpackhi_ps(k, l); + tm = _mm512_unpacklo_ps(m, n); + tn = _mm512_unpackhi_ps(m, n); + to = _mm512_unpacklo_ps(o, p); + tq = _mm512_unpackhi_ps(o, p); + + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... + // i0 j0 k0 l0 ... + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // m0 n0 o0 p0 ... + // m1 n1 o1 p1 ... + // m2 n2 o2 p2 ... + // m3 n3 o3 p3 ... 
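+  // (The _mm512_castps_pd / _mm512_castpd_ps casts below are pure
+  // reinterpretations and emit no instructions; they only let
+  // _mm512_unpacklo_pd / _mm512_unpackhi_pd interleave the rows in 64-bit,
+  // i.e. pair-of-float, chunks.)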
+ a = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); + b = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); + c = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); + d = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); + e = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); + f = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); + g = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); + h = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); + i = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); + j = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); + k = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); + l = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); + m = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); + n = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); + o = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); + p = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... + // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8 + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // i4 j4 k4 l4 ... + // i5 j5 k5 l5 ... + // i6 j6 k6 l6 ... + // i7 j7 k7 l7 ... + ta = _mm512_shuffle_f32x4(a, e, 0x88); + tb = _mm512_shuffle_f32x4(b, f, 0x88); + tc = _mm512_shuffle_f32x4(c, g, 0x88); + td = _mm512_shuffle_f32x4(d, h, 0x88); + te = _mm512_shuffle_f32x4(a, e, 0xdd); + tf = _mm512_shuffle_f32x4(b, f, 0xdd); + tg = _mm512_shuffle_f32x4(c, g, 0xdd); + th = _mm512_shuffle_f32x4(d, h, 0xdd); + ti = _mm512_shuffle_f32x4(i, m, 0x88); + tj = _mm512_shuffle_f32x4(j, n, 0x88); + tk = _mm512_shuffle_f32x4(k, o, 0x88); + tl = _mm512_shuffle_f32x4(l, p, 0x88); + tm = _mm512_shuffle_f32x4(i, m, 0xdd); + tn = _mm512_shuffle_f32x4(j, n, 0xdd); + to = _mm512_shuffle_f32x4(k, o, 0xdd); + tq = _mm512_shuffle_f32x4(l, p, 0xdd); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 ... o0 + // a1 b1 c1 d1 ... o1 + // a2 b2 c2 d2 ... o2 + // a3 b3 c3 d3 ... o3 + // a4 ... + // a5 ... + // a6 ... + // a7 ... + // a8 ... + // a9 ... + // a10 ... + // a11 ... + // a12 ... + // a13 ... + // a14 ... + // a15 b15 c15 d15 ... 
o15 + a = _mm512_shuffle_f32x4(ta, ti, 0x88); + b = _mm512_shuffle_f32x4(tb, tj, 0x88); + c = _mm512_shuffle_f32x4(tc, tk, 0x88); + d = _mm512_shuffle_f32x4(td, tl, 0x88); + e = _mm512_shuffle_f32x4(te, tm, 0x88); + f = _mm512_shuffle_f32x4(tf, tn, 0x88); + g = _mm512_shuffle_f32x4(tg, to, 0x88); + h = _mm512_shuffle_f32x4(th, tq, 0x88); + i = _mm512_shuffle_f32x4(ta, ti, 0xdd); + j = _mm512_shuffle_f32x4(tb, tj, 0xdd); + k = _mm512_shuffle_f32x4(tc, tk, 0xdd); + l = _mm512_shuffle_f32x4(td, tl, 0xdd); + m = _mm512_shuffle_f32x4(te, tm, 0xdd); + n = _mm512_shuffle_f32x4(tf, tn, 0xdd); + o = _mm512_shuffle_f32x4(tg, to, 0xdd); + p = _mm512_shuffle_f32x4(th, tq, 0xdd); + + // store from registers to dst + _mm512_storeu_ps(&dst[0 * ld_dst], a); + _mm512_storeu_ps(&dst[1 * ld_dst], b); + _mm512_storeu_ps(&dst[2 * ld_dst], c); + _mm512_storeu_ps(&dst[3 * ld_dst], d); + _mm512_storeu_ps(&dst[4 * ld_dst], e); + _mm512_storeu_ps(&dst[5 * ld_dst], f); + _mm512_storeu_ps(&dst[6 * ld_dst], g); + _mm512_storeu_ps(&dst[7 * ld_dst], h); + _mm512_storeu_ps(&dst[8 * ld_dst], i); + _mm512_storeu_ps(&dst[9 * ld_dst], j); + _mm512_storeu_ps(&dst[10 * ld_dst], k); + _mm512_storeu_ps(&dst[11 * ld_dst], l); + _mm512_storeu_ps(&dst[12 * ld_dst], m); + _mm512_storeu_ps(&dst[13 * ld_dst], n); + _mm512_storeu_ps(&dst[14 * ld_dst], o); + _mm512_storeu_ps(&dst[15 * ld_dst], p); +} + +void transpose_16x16( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 16 <= M; ib += 16) { + for (jb = 0; jb + 16 <= N; jb += 16) { + transpose_kernel_16x16_avx512( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +} // namespace fbgemm2 diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc new file mode 100644 index 0000000..8e36c85 --- /dev/null +++ b/src/codegen_fp16fp32.cc @@ -0,0 +1,387 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +void addi(ofstream& of, string i, bool disable = false) { + if (disable == false) + of << "\"" + i + "\\t\\n\"" + "\n"; +} + +struct ISA { + unsigned avx; // 1, 2 or 3 + string name; + vector> shapes; +}; + +int main() { + bool iaca = false; + bool disable = false; + + bool fixedA = true, fixedB = true, fixedC = true; + + int eax, ebx, ecx, edx; + __cpuid(1 /* ecx = vendor string */, eax, ebx, ecx, edx); + printf("FC16 is %s supported\n", ((ecx & bit_F16C) ? 
" " : "not")); + + string comma = ","; + + vector isa = { + // {1, "AVX", {{4, 1, 0}, {4, 2, 0}, {4, 3, 0}, {3, 1, 0}, {3, 2, 0}, {3, + // 3, 0}}}, + { 2, "AVX2", + { { 1, 1, 0 }, + { 2, 1, 0 }, + { 3, 1, 0 }, + { 4, 1, 0 }, + { 5, 1, 0 }, + { 6, 1, 0 }, + { 7, 1, 0 }, + { 8, 1, 0 }, + { 9, 1, 0 }, + { 10, 1, 0 }, + { 11, 1, 0 }, + { 12, 1, 0 }, + { 13, 1, 0 }, + { 14, 1, 0 }, + } + } + }; + + // open all files + ofstream srcfile; + srcfile.open("FbgemmFP16UKernels.cc"); + srcfile << "#include \"FbgemmFP16UKernels.h\"\n"; + if (iaca) + srcfile << "#include \"iacaMarks.h\"\n"; + + ofstream hdrfile; + hdrfile.open("FbgemmFP16UKernels.h"); + + hdrfile << "#ifndef FBGEMM_UKERNELS\n"; + hdrfile << "#define FBGEMM_UKERNELS\n"; + hdrfile << "#include \n"; + hdrfile << "#include \n"; + hdrfile << "#include \n"; + hdrfile << "#include \"fbgemm/Types.h\"\n"; + hdrfile << "using fp16 = fbgemm2::float16;\n"; + hdrfile << "using fp32 = float;\n"; + hdrfile << "struct GemmParams {uint64_t k; float *A; const fp16 *B;\n" + "float *beta; uint64_t accum; float *C; uint64_t ldc;\n" + "uint64_t b_block_cols; uint64_t b_block_size;};\n"; + + std::map fptr_typedef; + fptr_typedef["fp16"] = ""; + fptr_typedef["fp32"] = ""; + + unsigned labelId = 0; +#if 1 + for (auto fixedA : {false}) + for (auto fixedB : {false}) + for (auto fixedC : {false}) +#else + for (auto fixedA : {true}) + for (auto fixedB : {true}) + for (auto fixedC : {true}) +#endif + for (auto s : isa) { + vector>& ukernel_shape = s.shapes; + + vector funcname(ukernel_shape.size()), + fheader(ukernel_shape.size()); + string fargs; + + for (auto fp16 : {true}) { + string B_type = ((fp16) ? "fp16" : "fp32"); + string prefix = s.name + /*"_" + B_type */ + "_" + "fA" + + to_string(fixedA) + "fB" + to_string(fixedB) + "fC" + + to_string(fixedC); + cout << "Generating code for " << s.name << " " << B_type << "\n"; + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + printf( + "shape: %d x %d * 32\n", + ukernel_shape[k][0], + ukernel_shape[k][1]); + + string p1 = "GemmParams *gp"; + + funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) + + "x" + to_string(ukernel_shape[k][1]) + "_"; + funcname[k] += prefix; + + fargs = "(" + p1 + ")"; + + fheader[k] = + "void __attribute__ ((noinline)) " + funcname[k] + fargs; + srcfile << fheader[k] << "\n"; + srcfile << "{\n"; + + unsigned last_free_ymmreg = 0; + // produce register block of C + vector> vCtile(ukernel_shape[k][0]); + for (auto r = 0; r < ukernel_shape[k][0]; r++) + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vCtile[r].push_back("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + assert(last_free_ymmreg <= 14); + + string vAtmp = "ymm" + to_string(last_free_ymmreg++); + // produce register block of B col + assert(ukernel_shape[k][1] == 1); + vector vBcol(ukernel_shape[k][1]); + + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vBcol[c] = ("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + + assert(last_free_ymmreg <= 16); + + srcfile << "asm volatile\n"; + srcfile << "(\n"; + + srcfile << "#if !defined(__clang__)" << "\n"; + addi(srcfile, "mov r14, %[gp]"); + srcfile << "#else\n"; + addi(srcfile, "mov %[gp], %%r14"); + addi(srcfile, ".intel_syntax noprefix"); + srcfile << "#endif\n"; + + srcfile << "\n// Copy parameters\n"; + srcfile << "// k\n"; + addi(srcfile, "mov r8, [r14 + 0]"); + srcfile << "// A\n"; + addi(srcfile, "mov r9, [r14 + 8]"); + srcfile << "// B\n"; + addi(srcfile, "mov r10, [r14 + 16]"); + srcfile << "// beta\n"; + addi(srcfile, "mov 
r15, [r14 + 24]"); + srcfile << "// accum\n"; + addi(srcfile, "mov rdx, [r14 + 32]"); + srcfile << "// C\n"; + addi(srcfile, "mov r12, [r14 + 40]"); + srcfile << "// ldc\n"; + addi(srcfile, "mov r13, [r14 + 48]"); + srcfile << "// b_block_cols\n"; + addi(srcfile, "mov rdi, [r14 + 56]"); + srcfile << "// b_block_size\n"; + addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << "// Make copies of A and C\n"; + addi(srcfile, "mov rax, r9"); + addi(srcfile, "mov rcx, r12"); + srcfile << "\n\n"; + + addi(srcfile, "mov rbx, 0"); + + string exitlabel = "L_exit%="; + string label2 = "loop_outter%="; + addi(srcfile, label2 + ":"); + addi(srcfile, "mov r14, 0"); + + // set all vCtile regs to zeros + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + vCtile[r][c]); + } + } + + // start marker + if (iaca) { + addi(srcfile, "mov ebx, 111"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + srcfile << "\n"; + + if (ukernel_shape[k][0] <= 13) { + addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]"); + addi(srcfile, "mov r11, 16"); + } else { + addi(srcfile, "mov r11, 0"); + } + + srcfile << "\n"; + string label = "loop_inner%="; + addi(srcfile, label + ":"); + srcfile << "\n"; + + if (ukernel_shape[k][0] <= 13) { + auto a_offset = 0, unroll_factor = 2; + for (auto u = 0; u < unroll_factor; u++) { + string breg = (u == 0) ? "ymm14" : "ymm15"; + string breg_rev = (u == 0) ? "ymm15" : "ymm14"; + + addi(srcfile, "vcvtph2ps " + breg + + ",XMMWORD PTR [r10 + r11 + " + + to_string(u * 16) + "]"); + addi(srcfile, "inc r14"); + for (auto r = 0; r < vCtile.size(); r++) { + addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(a_offset) + "]"); + addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + + breg_rev + "," + vAtmp); + if (u == 1 && r == vCtile.size() / 2) + addi(srcfile, "add r11, 32"); + a_offset += 4; + } + if (u < unroll_factor - 1) { + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jge " + exitlabel); + } + } + + addi(srcfile, "add r9," + to_string(a_offset)); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); + } else { + addi(srcfile, + "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]"); + for (auto r = 0; r < vCtile.size(); r++) { + addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(4 * r) + "]"); + addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] + + "," + vAtmp); + } + + addi(srcfile, "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr + addi(srcfile, "add r11, 16"); + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + } + + addi(srcfile, "add r10, rsi"); + srcfile << "\n"; + + // end marker + if (iaca) { + addi(srcfile, "mov ebx, 222"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + + addi(srcfile, "cmp rdx, 1"); + addi(srcfile, "je L_accum%="); + srcfile << "// Dump C\n"; + + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi(srcfile, "vmovups YMMWORD PTR [r12 + " + + to_string(32 * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + addi(srcfile, "jmp L_done%="); + + srcfile << "\n\n"; + addi(srcfile, "L_accum%=:"); + srcfile << "// Dump C with accumulate\n"; + + string r_spare = (s.avx == 1) ? 
"ymm14" : "ymm15"; + addi(srcfile, + "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), + fixedC); + // store out C + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + switch (s.avx) { + case 1: + addi(srcfile, + string("vmulps ymm15, ") + r_spare + comma + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + addi(srcfile, "vaddps " + vCtile[r][c] + "," + + vCtile[r][c] + "," + "ymm15", + fixedC); + break; + case 2: + addi(srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + break; + default: + assert(0); + } + addi(srcfile, "vmovups YMMWORD PTR [r12 + " + + to_string(32 * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + + srcfile << "\n"; + addi(srcfile, "L_done%=:"); + + srcfile << "\n// next outer iteration\n"; + // C + addi(srcfile, "add rcx, " + to_string(32 * ukernel_shape[k][1]), + fixedC); + addi(srcfile, "mov r12, rcx", fixedC); + // A + addi(srcfile, "mov r9, rax"); + + addi(srcfile, "inc rbx"); + addi(srcfile, "cmp rbx, rdi"); + addi(srcfile, "jl " + label2); + + // output + srcfile << ":\n"; + // input + srcfile << ":\n"; + srcfile << "[gp] \"rm\" (gp)\n"; + + // clobbered + srcfile + << (string) ": \"r8\", \"r9\", \"r10\", \"r11\", \"r15\", " + + (string) " \"r13\", \"r14\",\n" + + (string) "\"rax\", \"rcx\", " + "\"rdx\", \"rsi\", \"rdi\", \"rbx\", " + "\"r12\", \"memory\"" + + (string) "\n"; + srcfile << ");\n"; + srcfile << "}\n"; + } + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + hdrfile << fheader[k] << ";\n"; + } + + fptr_typedef[B_type] = + "typedef void (* funcptr_" + B_type + ") " + fargs; + } + } + + srcfile.close(); + hdrfile << fptr_typedef["fp16"] << ";\n"; + hdrfile << fptr_typedef["fp32"] << ";\n"; + hdrfile << "#endif\n"; + hdrfile.close(); +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..9005dab --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.7 FATAL_ERROR) + +if(FBGEMM_BUILD_TESTS AND NOT TARGET gtest) + #Download Googletest framework from github if + #GOOGLETEST_SOURCE_DIR is not specified. + if(NOT DEFINED GOOGLETEST_SOURCE_DIR) + message(STATUS "Downloading googletest to + ${FBGEMM_THIRDPARTY_DIR}/googletest + (define GOOGLETEST_SOURCE_DIR to avoid it)") + configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadGTEST.cmake" + "${FBGEMM_BINARY_DIR}/googletest-download/CMakeLists.txt") + execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/googletest-download") + execute_process(COMMAND "${CMAKE_COMMAND}" --build . 
+ WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/googletest-download") + set(GOOGLETEST_SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/googletest" CACHE STRING + "googletest source directory") + endif() + + #build Googletest framework + add_subdirectory("${GOOGLETEST_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/googletest") +endif() + +macro(add_gtest TESTNAME) + add_executable(${TESTNAME} ${ARGN} + ../bench/BenchUtils.cc QuantizationHelpers.cc TestUtils.cc) + set_target_properties(${TESTNAME} PROPERTIES + CXX_STANDARD 11 + CXX_EXTENSIONS NO) + target_compile_options(${TESTNAME} PRIVATE + "-m64" "-mavx2" "-mfma" "-masm=intel") + target_link_libraries(${TESTNAME} gtest gmock gtest_main fbgemm) + add_dependencies(${TESTNAME} gtest fbgemm) + add_test(${TESTNAME} ${TESTNAME}) + set_target_properties(${TESTNAME} PROPERTIES FOLDER test) +endmacro() + + +file(GLOB TEST_LIST "*Test.cc") + +foreach(TEST_FILE ${TEST_LIST}) + get_filename_component(TEST_NAME "${TEST_FILE}" NAME_WE) + get_filename_component(TEST_FILE_ONLY "${TEST_FILE}" NAME) + add_gtest("${TEST_NAME}" + "${TEST_FILE_ONLY}") +endforeach() + diff --git a/test/FP16Test.cc b/test/FP16Test.cc new file mode 100644 index 0000000..c346049 --- /dev/null +++ b/test/FP16Test.cc @@ -0,0 +1,124 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include + +#include "fbgemm/FbgemmFP16.h" +#include "src/RefImplementations.h" +#include "bench/BenchUtils.h" +#include "TestUtils.h" + +#ifdef USE_IACA +#include "iacaMarks.h" +#endif + +using namespace std; +using namespace fbgemm2; + +namespace { + // The template parameter is transpose of A and B + class FBGemmFP16Test : + public testing::TestWithParam> {}; +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + FBGemmFP16Test, + ::testing::Values( + pair( + matrix_op_t::NoTranspose, matrix_op_t::NoTranspose), + pair( + matrix_op_t::NoTranspose, matrix_op_t::Transpose)/*, + pair( + matrix_op_t::Transpose, matrix_op_t::NoTranspose), + pair( + matrix_op_t::Transpose, matrix_op_t::Transpose)*/)); + +TEST_P(FBGemmFP16Test, Test) { + vector > shapes; + random_device r; + default_random_engine generator(r()); + uniform_int_distribution dm(1,100); + uniform_int_distribution dnk(1,1024); + for (int i = 0; i < 10; i++) { + int m = dm(generator); + int n = dnk(generator); + int k = dnk(generator); + shapes.push_back({m, n, k}); + if (m > 10) { + shapes.push_back({(m / 10) * 10, n, k}); + } + } + + float alpha = 1.f, beta = 0.f; + matrix_op_t atrans, btrans; + tie(atrans, btrans) = GetParam(); + + for (auto s : shapes) { + int m = s[0]; + int n = s[1]; + int k = s[2]; + + cerr << "m = " << m << " n = " << n << " k = " << k; + if (atrans == matrix_op_t::Transpose) { + cerr << " A_transposed"; + } + if (btrans == matrix_op_t::Transpose) { + cerr << " B_transposed"; + } + cerr << endl; + + aligned_vector A(m * k, 0.f); + aligned_vector B(k * n, 0.f); + aligned_vector C(m * n, 0.f); + + // initialize with small numbers + randFill(A, 0, 4); + randFill(B, 0, 4); + randFill(C, 0, 4); + + aligned_vector A_ref, B_ref, C_ref; + A_ref = A; + B_ref = B; + C_ref = C; + + if (atrans == matrix_op_t::Transpose) { + transpose_matrix(A_ref.data(), k, m); + } + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(B_ref.data(), n, k); + } + + // Gold via reference sgemm + matmul_fp_ref( + m, + n, + k, + k, + n, + n, + A_ref.data(), + B_ref.data(), + 
C_ref.data()); + + // fbgemm fp16 + PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); + cblas_gemm_compute(atrans, m, A.data(), Bp, beta, C.data()); + + // correctness check + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float expected = C_ref[i * n + j]; + float actual = C[i * n + j]; + EXPECT_EQ(expected, actual) << + "GEMM results differ at (" << i << ", " << j << + "). ref " << expected << " FBGemm " << actual; + } + } + } +} diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc new file mode 100644 index 0000000..cfde880 --- /dev/null +++ b/test/I8DepthwiseTest.cc @@ -0,0 +1,448 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "I8DepthwiseTest.h" + +#include +#include + +#include + +#include "bench/AlignedVec.h" +#include "src/FbgemmI8Depthwise.h" +#include "src/RefImplementations.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" + +using namespace std; + +namespace fbgemm2 +{ + +// From Xray OCR +static vector> shapes = { + // N, K, H_in, W_in, stride + { 1, 272, 47, 125, 1, }, +// { 1, 272, 64, 125, 1, }, +// { 1, 272, 66, 125, 1, }, +// { 1, 272, 67, 100, 1, }, +// { 1, 272, 75, 75, 1, }, + { 1, 272, 75, 76, 1, }, +// { 1, 272, 75, 100, 1, }, +// { 1, 272, 94, 75, 1, }, +// { 1, 272, 109, 75, 1, }, + { 1, 544, 24, 63, 1, }, +// { 1, 544, 33, 63, 1, }, +// { 1, 544, 34, 50, 1, }, +// { 1, 544, 36, 63, 1, }, +// { 1, 544, 38, 38, 1, }, +// { 1, 544, 38, 40, 1, }, + { 1, 544, 47, 38, 1, }, + { 1, 1088, 7, 7, 1, }, + { 51, 1088, 7, 7, 1, }, +// { 100, 1088, 7, 7, 1, }, + + { 1, 248, 93, 250, 2, }, +// { 1, 248, 128, 250, 2, }, +// { 1, 248, 133, 200, 2, }, +// { 1, 248, 150, 150, 2, }, + { 1, 248, 150, 151, 2, }, +// { 1, 248, 150, 158, 2, }, +// { 1, 248, 188, 150, 2, }, +// { 1, 248, 225, 150, 2, }, + { 1, 272, 47, 125, 2, }, +// { 1, 272, 64, 125, 2, }, +// { 1, 272, 66, 125, 2, }, +// { 1, 272, 67, 100, 2, }, +// { 1, 272, 75, 75, 2, }, +// { 1, 272, 75, 76, 2, }, + { 1, 272, 94, 75, 2, }, + { 1, 544, 14, 14, 2, }, + { 51, 544, 14, 14, 2, }, +// { 100, 544, 14, 14, 2, }, + + { 1, 8, 4, 4, 1, }, +}; + +TEST(FBGemmDepthWiseTest, Test3x3) { + for (auto shape : shapes) { + int N = shape[0]; + int K = shape[1]; + int H = shape[2]; + int W = shape[3]; + int stride_h = shape[4]; + int stride_w = stride_h; + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + aligned_vector A(N * H * W * K); + aligned_vector B(K * R * S); + aligned_vector C_ref(N * H_OUT * W_OUT * K), C(C_ref.size()); + + randFill(A, 0, 86); + int32_t A_zero_point = 43; + + randFill(B, -16, 16); + int32_t B_zero_point = 5; + + depthwise_3x3_pad_1_ref( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); + int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + + float C_multiplier = 255. 
/ (maximum - minimum); + + aligned_vector col_offsets(K); + aligned_vector bias(K); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + int32_t C_zero_point = 5; + + aligned_vector C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3_pad_1_ref( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + B.data(), + C_multiplier, + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3ConvMatrix Bp(K, B.data()); + + depthwise_3x3_pad_1( + N, H, W, K, stride_h, stride_w, A_zero_point, A.data(), Bp, C.data()); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = C_ref[((n * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = C[((n * H_OUT + h) * W_OUT + w) * K + k]; + EXPECT_EQ(expected, actual) << + "Depthwise 3x3 results differ at (" << n << ", " << + h << ", " << w << ", " << k << ")."; + } + } + } + } + + depthwise_3x3_pad_1( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_uint8.data(), + col_offsets.data(), + bias.data(), + 0, + 1); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = + C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + k]; + EXPECT_EQ(expected, actual) << + "Depthwise 3x3 results differ at (" << n << ", " << + h << ", " << w << ", " << k << ")."; + } + } + } + } + } // for each shape +} // Test3x3 + +TEST(FBGemmDepthWiseTest, Test3x3x3) { + for (auto shape : shapes_3d) { + int N = shape[0]; + int K = shape[1]; + int T = shape[2]; + int H = shape[3]; + int W = shape[4]; + int stride_t = shape[5]; + int stride_h = stride_t; + int stride_w = stride_t; + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + aligned_vector A(N * T * H * W * K); + aligned_vector B(K * K_T * K_H * K_W); + aligned_vector C_ref(N * T_OUT * H_OUT * W_OUT * K), + C(C_ref.size()); + + randFill(A, 0, 86); + int32_t A_zero_point = 43; + + randFill(B, -16, 16); + int32_t B_zero_point = 5; + + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); + int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + + float C_multiplier = 255. 
/ (maximum - minimum); + + aligned_vector col_offsets(K); + aligned_vector bias(K); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + int32_t C_zero_point = 5; + + aligned_vector C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point, + B.data(), + C_multiplier, + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3x3ConvMatrix Bp(K, B.data()); + + depthwise_3x3x3_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + Bp, + C.data()); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = + C_ref[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = + C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + ASSERT_EQ(expected, actual) + << "Depthwise 3x3 results differ at (" << n << ", " << t + << ", " << h << ", " << w << ", " << k << ") " + << shape[0] << " " << shape[1] << " " << shape[2] << " " + << shape[3] << " " << shape[4] << " " << shape[5]; + } + } // w + } // h + } // t + } // n + + depthwise_3x3x3_pad_1( + N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A.data(), + B_zero_point, Bp, C_multiplier, C_zero_point, + C_uint8.data(), col_offsets.data(), bias.data(), + false /* fuse_relu */, 0, 1); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = C_uint8_ref + [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = + C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + EXPECT_EQ(expected, actual) + << "Depthwise 3x3 results differ at (" << n << ", " << t + << ", " << h << ", " << w << ", " << k << ")."; + } + } // w + } // h + } // t + } // n + } // for each shape +} // Test3x3x3 + +TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { + for (auto shape : shapes) { + int N = shape[0]; + int K = shape[1]; + int H = shape[2]; + int W = shape[3]; + int stride_h = shape[4]; + int stride_w = stride_h; + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + aligned_vector A(N * H * W * K); + aligned_vector B(K * R * S); + int32_t C_num_rows = N * H_OUT * W_OUT; + aligned_vector C_ref(C_num_rows * K), C(C_ref.size()); + + randFill(A, 0, 86); + int32_t A_zero_point = 43; + + // Each row of G has a different range to really test per-channel + // quantization. 
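+    // (The "G" above is the filter tensor B: each output channel k is given
+    // its own value range, B_zero_point[k], and C_multiplier[k] below, so a
+    // per-channel bug would not be hidden by identical channel statistics.)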
+ vector B_zero_point(K); + for (auto k = 0; k < K; ++k) { + aligned_vector Bk(R * S); + randFill(Bk, -16 + k, 16 + k); + copy(Bk.begin(), Bk.end(), B.begin() + k * R * S); + + B_zero_point[k] = 5 + k; + } + + depthwise_3x3_pad_1_ref( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + aligned_vector C_ref_transpose(C_ref); + transpose_matrix(C_ref.data(), C_num_rows, K); + vector C_multiplier(K); + for (auto k = 0; k < K; ++k) { + auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; + auto C_ref_k_end = C_ref_k_begin + C_num_rows; + int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); + int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); + C_multiplier[k] = 255. / (maximum - minimum); + cerr << "k " << k << " minimum " << minimum << " maximum " << maximum + << " multiplier " << C_multiplier[k] << endl; + } + int32_t C_zero_point = 5; + + aligned_vector col_offsets(K); + aligned_vector bias(K); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + + aligned_vector C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3_per_channel_quantization_pad_1_ref( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point.data(), + B.data(), + C_multiplier.data(), + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3ConvMatrix Bp(K, B.data()); + + depthwise_3x3_per_channel_quantization_pad_1( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point.data(), + Bp, + C_multiplier.data(), + C_zero_point, + C_uint8.data(), + col_offsets.data(), + bias.data(), + 0, + 1); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = + C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + k]; + EXPECT_EQ(expected, actual) << + "Depthwise 3x3 results differ at (" << n << ", " << + h << ", " << w << ", " << k << ")."; + } + } + } + } + } // for each shape +} // Test3x3PerChannelQuantization + +} // namespace fbgemm2 diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h new file mode 100644 index 0000000..38dea6f --- /dev/null +++ b/test/I8DepthwiseTest.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +namespace fbgemm2 +{ + +// From ResNeXt-3D-101 +static std::vector> shapes_3d = { + // N, K, T_in, H_in, W_in, stride + { 1, 64, 32, 56, 56, 1, }, + { 1, 128, 16, 28, 28, 1, }, + { 1, 256, 8, 14, 14, 1, }, + { 1, 512, 4, 7, 7, 1, }, + + { 1, 128, 32, 56, 56, 2, }, + { 1, 256, 16, 28, 28, 2, }, + { 1, 512, 8, 14, 14, 2, }, + + { 5, 64, 32, 56, 56, 1, }, + { 5, 128, 16, 28, 28, 1, }, + { 5, 256, 8, 14, 14, 1, }, + { 5, 512, 4, 7, 7, 1, }, + + { 5, 128, 32, 56, 56, 2, }, + { 5, 256, 16, 28, 28, 2, }, + { 5, 512, 8, 14, 14, 2, }, + + { 1, 8, 4, 4, 4, 1, }, +}; + +} // namespace fbgemm2 diff --git a/test/I8SpmdmTest.cc b/test/I8SpmdmTest.cc new file mode 100644 index 0000000..c74c98a --- /dev/null +++ b/test/I8SpmdmTest.cc @@ -0,0 +1,158 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. 
+ * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#ifdef _OPENMP +#include +#endif + +#include "fbgemm/FbgemmI8Spmdm.h" +#include "src/RefImplementations.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" + +using namespace std; +using namespace fbgemm2; + +std::vector densities{0.0001f, 0.001f, 0.01f, 0.1f, 1.0f}; + +namespace { +class fbgemmSPMDMTest + : public testing::TestWithParam> {}; +} // namespace + +INSTANTIATE_TEST_CASE_P( + Instance0, + fbgemmSPMDMTest, + ::testing::Combine( + ::testing::ValuesIn(densities), + ::testing::Bool(), + ::testing::Bool())); + +TEST_P(fbgemmSPMDMTest, TestsSpMDM) { + const vector> shapes = { + // M, N, K + { 1024, 1024, 1024 }, + { 511, 512, 512 }, + { 111, 111, 111 }, + { 14*14*2, 4, 2 }, + }; + + float density; + bool accumulation, test_ld; + tie(density, accumulation, test_ld) = GetParam(); + + for (const auto& shape : shapes) { + int M = shape[0]; + int N = shape[1]; + int K = shape[2]; + int N_adjusted = N; + int K_adjusted = K; + if (test_ld) { + // When test_ld is true, we multiply with the bottom-right quadrant of B + N_adjusted = std::max(N / 2, 1); + K_adjusted = std::max(K / 2, 1); + } + + aligned_vector A(M * K); + randFill(A, 0, 255); + + CompressedSparseColumn B_csc(K_adjusted, N_adjusted); + vector C(M * N); + vector C_ref(C.size()); + + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + C_ref[i * N + j] = i + j; + } + } + + // deterministic random number + default_random_engine eng; + binomial_distribution<> per_col_nnz_dist(K_adjusted, density); + uniform_int_distribution<> value_dist( + numeric_limits::min() / 2, + numeric_limits::max() / 2); + + vector row_indices(K_adjusted); + + int total_nnz = 0; + for (int j = 0; j < N_adjusted; ++j) { + B_csc.ColPtr()[j] = total_nnz; + + int nnz_of_j = per_col_nnz_dist(eng); + total_nnz += nnz_of_j; + + iota(row_indices.begin(), row_indices.end(), 0); + shuffle(row_indices.begin(), row_indices.end(), eng); + sort(row_indices.begin(), row_indices.begin() + nnz_of_j); + + for (int k = 0; k < nnz_of_j; ++k) { + B_csc.RowIdx().push_back(row_indices[k]); + B_csc.Values().push_back(value_dist(eng)); + } + } + B_csc.ColPtr()[N_adjusted] = total_nnz; + + spmdm_ref( + M, + A.data() + (test_ld ? K_adjusted : 0), + K, + B_csc, + accumulation, + C_ref.data() + (test_ld ? N_adjusted : 0), + N); + + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + if (accumulation) { + C[i * N + j] = i + j; + } else { + C[i * N + j] = i + j + 1; + } + } + } + +#pragma omp parallel + { +#ifdef _OPENMP + int num_threads = omp_get_num_threads(); + int tid = omp_get_thread_num(); +#else + int num_threads = 1; + int tid = 0; +#endif + int i_per_thread = (M + num_threads - 1) / num_threads; + int i_begin = std::min(tid * i_per_thread, M); + int i_end = std::min(i_begin + i_per_thread, M); + + block_type_t block = {i_begin, i_end - i_begin, 0, N_adjusted}; + B_csc.SpMDM( + block, + A.data() + (test_ld ? K_adjusted : 0), + K, + accumulation, + C.data() + i_begin * N + (test_ld ? N_adjusted : 0), + N); + } + + compare_validate_buffers( + C_ref.data() + (test_ld ? N_adjusted : 0), + C.data() + (test_ld ? 
N_adjusted : 0), + M, + N_adjusted, + N, + static_cast(0)); + } // for each shape +} diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc new file mode 100644 index 0000000..28e1114 --- /dev/null +++ b/test/PackedRequantizeAcc16Test.cc @@ -0,0 +1,535 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include + +#include "fbgemm/Fbgemm.h" +#include "bench/BenchUtils.h" +#include "src/RefImplementations.h" +#include "QuantizationHelpers.h" +#include "TestUtils.h" + +using namespace std; +using namespace fbgemm2; + +std::vector transposeVals{matrix_op_t::NoTranspose, + matrix_op_t::Transpose}; + +namespace { +class fbgemmu8s8acc16test : public testing::TestWithParam< + std::tuple> {}; +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmu8s8acc16test, + ::testing::Combine( + ::testing::Values(matrix_op_t::NoTranspose), + ::testing::ValuesIn(transposeVals), + ::testing::Bool())); + +/** + * @brief Shapes for unit test. + */ +static vector> GetShapes_() { + // NMT + vector> shapes = { + // {M, N, K} + {1, 128, 512}, + {1, 1024, 256}, + {1, 2048, 512}, + {1, 2048, 513}, + {1, 2048, 514}, + + {6, 512, 512}, + {6, 2048, 512}, + {6, 256, 1024}, + {6, 1024, 256}, + {6, 2048, 256}, + {6, 2048, 257}, + {6, 2048, 258}, + + {102, 1024, 512}, + {102, 2323, 256}, + {102, 512, 256}, + {102, 512, 257}, + {102, 512, 258}, + }; + return shapes; +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit + * accumulation. Output processing: requantization -> nothing + */ +TEST_P(fbgemmu8s8acc16test, Test) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k, 0); + aligned_vector Bint8(k * n, 0); + aligned_vector Bint8_ref(k * n, 0); + aligned_vector Cint32_local(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint8_local(m * n, 0); + + randFill(Aint8, 0, 255); + int32_t Aint8_zero_point = 43; + + randFill(Bint8_ref, -128, 127); + + for (auto i = 0; i < Bint8.size(); ++i) { + Bint8[i] = Bint8_ref[i]; + } + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + int32_t Bint8_zero_point = -30; + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
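+    // (That is, with test_ld set the kernels see a smaller logical matrix
+    // while the strides stay at the full k and n, so it is the
+    // leading-dimension handling that gets exercised here.)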
+ int k_adjusted = k; + int n_adjusted = n; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8_ref.data(), + Bint8_zero_point, + col_offsets.data()); + + vector row_offsets; + row_offsets.resize(m); + + float C_multiplier = 0.1234; + int32_t C_zero_pt = 5; + + int brow = 256; + matmul_u8i8acc16_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + brow, + Aint8.data(), + Bint8_ref.data(), + Cint32_local.data()); + + row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data()); + + requantize_u8acc32_ref( + m, + n_adjusted, + n, + Cint32_local.data(), + Cint8_local.data(), + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + row_offsets.data(), + col_offsets.data(), + nullptr); + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithRowOffset::rowOffsetBufferSize()); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k_adjusted, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + PackBMatrix packedBN( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k : n, + nullptr, + 1, + Bint8_zero_point); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput outputProcObj( + doNothingObj, + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + fbgemmPacked( + packAN, + packedBN, + Cint8_fb.data(), + Cint32_buffer.data(), + n, + outputProcObj, + 0, + 1); + + compare_validate_buffers( + Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast(0)); + } +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit + * accumulation. Output processing: spmdm -> requantization -> nothing + */ +TEST_P(fbgemmu8s8acc16test, SpMDMTest) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k, 0); + aligned_vector Bint8(k * n, 0); + aligned_vector Bint8_ref(k * n, 0); + aligned_vector Cint32_local(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint8_local(m * n, 0); + + randFill(Aint8, 0, 255); + int32_t Aint8_zero_point = 43; + + randFill(Bint8, -128, 127); + + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
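+    // (Besides the lda handling, this test splits B: a random sparse subset
+    // of its entries is moved into B_csc further below and zeroed out in the
+    // dense Bint8, so the dense GEMM plus the accumulating SpMDM pass should
+    // together reproduce the product with the original, unsplit B.)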
+ int k_adjusted = k; + int n_adjusted = n; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + int32_t Bint8_zero_point = -30; + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8.data(), + Bint8_zero_point, + col_offsets.data()); + + CompressedSparseColumn B_csc(k_adjusted, n_adjusted); + float density = 0.001f; + // deterministic random number + default_random_engine eng; + binomial_distribution<> per_col_nnz_dist(k_adjusted, density); + uniform_int_distribution<> value_dist( + numeric_limits::min() / 2, numeric_limits::max() / 2); + + vector row_indices(k_adjusted); + int total_nnz = 0; + for (int j = 0; j < n_adjusted; ++j) { + B_csc.ColPtr()[j] = total_nnz; + + int nnz_of_j = per_col_nnz_dist(eng); + total_nnz += nnz_of_j; + + iota(row_indices.begin(), row_indices.end(), 0); + shuffle(row_indices.begin(), row_indices.end(), eng); + sort(row_indices.begin(), row_indices.begin() + nnz_of_j); + + for (int kidx = 0; kidx < nnz_of_j; ++kidx) { + B_csc.RowIdx().push_back(row_indices[kidx]); + // put the current B value + B_csc.Values().push_back(Bint8[row_indices[kidx] * n + j]); + // make current B value zero + Bint8[row_indices[kidx] * n + j] = 0; + } + } + B_csc.ColPtr()[n_adjusted] = total_nnz; + + for (auto i = 0; i < Bint8.size(); ++i) { + Bint8_ref[i] = Bint8[i]; + } + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + vector row_offsets; + row_offsets.resize(m); + + float C_multiplier = 0.1234; + int32_t C_zero_pt = 5; + + int brow = 256; + matmul_u8i8acc16_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + brow, + Aint8.data(), + Bint8_ref.data(), + Cint32_local.data()); + + bool accumulation = true; + spmdm_ref(m, Aint8.data(), k, B_csc, accumulation, Cint32_local.data(), n); + + row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data()); + + requantize_u8acc32_ref( + m, + n_adjusted, + n, + Cint32_local.data(), + Cint8_local.data(), + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + row_offsets.data(), + col_offsets.data(), + nullptr); + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithRowOffset::rowOffsetBufferSize()); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k_adjusted, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + // spmdm -> requantization -> nothing + // construct an output processing pipeline in reverse order + // i.e. last output operation first + // Last operation should always be DoNothing with + // correct input and output type. 
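+    // Concretely, the statements below nest the stages inside out:
+    //   doNothingObj -- terminates the pipeline;
+    //   reqObj       -- wraps doNothingObj and requantizes the int32
+    //                   accumulators down to uint8;
+    //   spmdmObj     -- wraps reqObj and applies the sparse part of B to the
+    //                   int32 accumulators before requantization runs.
+    // At run time the data therefore flows spmdm -> requantization -> nothing.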
+ DoNothing<> doNothingObj{}; + // The second last operation is requantization back + // to int8 + ReQuantizeOutput reqObj( + doNothingObj, + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + // the top most (first) operation in the output processing + // pipeline is spmdm + // outType = final output type after fullly processing through pipeline + // inType = initial input type at the first call to the whole pipeline + DoSpmdmOnInpBuffer< + ReQuantizeOutput::outType, + int32_t, + ReQuantizeOutput> + spmdmObj(reqObj, Aint8.data(), k, B_csc); + + PackBMatrix packedB( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k : n); + + fbgemmPacked( + packAN, packedB, Cint8_fb.data(), Cint32_fb.data(), n, spmdmObj, 0, 1); + + compare_validate_buffers( + Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast(0)); + } +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 16-bit + * accumulation. Output processing: nothing + */ +TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k, 0); + aligned_vector Bint8(k * n, 0); + aligned_vector Bint8_ref(k * n, 0); + aligned_vector Cint32_local(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint8_local(m * n, 0); + + randFill(Aint8, 0, 255); + int32_t Aint8_zero_point = 43; + + randFill(Bint8_ref, -128, 127); + + for (auto i = 0; i < Bint8.size(); ++i) { + Bint8[i] = Bint8_ref[i]; + } + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + int32_t Bint8_zero_point = -30; + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int k_adjusted = k; + int n_adjusted = n; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8_ref.data(), + Bint8_zero_point, + col_offsets.data()); + + vector row_offsets; + row_offsets.resize(m); + + int brow = 256; + matmul_u8i8acc16_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + brow, + Aint8.data(), + Bint8_ref.data(), + Cint32_local.data()); + + row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data()); + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithRowOffset::rowOffsetBufferSize()); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k_adjusted, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + PackBMatrix packedBN( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? 
k : n, + nullptr, + 1, + Bint8_zero_point); + + // DoNothing<> doNothingObj{}; + DoNothing doNothingObj{}; + memCopy<> outputProcObj(doNothingObj); + fbgemmPacked( + packAN, + packedBN, + Cint32_fb.data(), + Cint32_buffer.data(), + n, + outputProcObj, + 0, + 1); + + compare_validate_buffers( + Cint32_local.data(), + Cint32_fb.data(), + m, + n, + n, + static_cast(0)); + } +} diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc new file mode 100644 index 0000000..4b1b5b8 --- /dev/null +++ b/test/PackedRequantizeTest.cc @@ -0,0 +1,625 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include + +#include "fbgemm/Fbgemm.h" +#include "bench/BenchUtils.h" +#include "src/RefImplementations.h" +#include "QuantizationHelpers.h" +#include "TestUtils.h" + +using namespace std; +using namespace fbgemm2; + +std::vector transposeVals{matrix_op_t::NoTranspose, + matrix_op_t::Transpose}; + +namespace { +class fbgemmu8s8acc32test : public testing::TestWithParam< + std::tuple> {}; +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmu8s8acc32test, + ::testing::Combine( + ::testing::Values(matrix_op_t::NoTranspose), + ::testing::ValuesIn(transposeVals), + ::testing::Bool())); + +/** + * @brief Shapes for unit test. + */ +static vector> GetShapes_() { + // NMT + vector> shapes = { + // {M, N, K} + {1, 128, 512}, + {1, 1024, 256}, + {1, 2048, 512}, + {1, 2048, 513}, + {1, 2048, 514}, + + {6, 512, 512}, + {6, 2048, 512}, + {6, 256, 1024}, + {6, 1024, 256}, + {6, 2048, 256}, + {6, 2048, 257}, + {6, 2048, 258}, + + {102, 1024, 512}, + {102, 2323, 256}, + {102, 512, 256}, + {102, 512, 257}, + {102, 512, 258}, + }; + return shapes; +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit + * accumulation. Output processing: requantization -> nothing + */ +TEST_P(fbgemmu8s8acc32test, Test) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k, 0); + + // nxk matrix + aligned_vector Bint8(k * n, 0); + // kxn matrix + aligned_vector Bint8_ref(k * n, 0); + + aligned_vector Cint32_ref(m * n, 0.0f); + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint32_local(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + aligned_vector Cint8_local(m * n, 0); + + randFill(Aint8, 0, 255); + int32_t Aint8_zero_point = 43; + + randFill(Bint8_ref, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bint8_ref.data()); + + for (auto i = 0; i < Bint8.size(); ++i) { + Bint8[i] = Bint8_ref[i]; + } + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + int32_t Bint8_zero_point = -30; + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
+ int k_adjusted = k; + int n_adjusted = n; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8_ref.data(), + Bint8_zero_point, + col_offsets.data()); + + vector row_offsets; + row_offsets.resize(m); + + float C_multiplier = 0.1234; + int32_t C_zero_pt = 5; + + matmul_u8i8acc32_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + Aint8.data(), + Bint8_ref.data(), + Cint32_local.data()); + + row_offsets_u8acc32_ref(m, k_adjusted, k, Aint8.data(), row_offsets.data()); + + requantize_u8acc32_ref( + m, + n_adjusted, + n, + Cint32_local.data(), + Cint8_local.data(), + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + row_offsets.data(), + col_offsets.data(), + nullptr); + + vector row_offset_buf; + row_offset_buf.resize(PackAWithRowOffset::rowOffsetBufferSize()); + + PackAWithRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k_adjusted, + Aint8.data(), + k, + nullptr, + 1, + Aint8_zero_point, + row_offset_buf.data()); + + PackBMatrix packedBN( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k : n, + nullptr, + 1, + Bint8_zero_point); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput outputProcObj( + doNothingObj, + C_multiplier, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + fbgemmPacked( + packAN, + packedBN, + Cint8_fb.data(), + Cint32_buffer.data(), + n, + outputProcObj, + 0, + 1); + // printMatrix(matrix_op_t::NoTranspose, Cint32_local.data(), + // m, n_adjusted, n, "C local"); + compare_validate_buffers( + Cint8_local.data(), Cint8_fb.data(), m, n, n, static_cast(0)); + } +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit + * accumulation. Directly output fp32 matrix C. Output processing: + * requantization -> nothing + */ +TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Afp32(m * k, 0.0f); + aligned_vector Aint8(m * k, 0); + + aligned_vector Bfp32(k * n, 0.0f); + aligned_vector Bint8(k * n, 0); + + aligned_vector Cfp32_ref(m * n, 0.0f); + aligned_vector Cfp32_fb(m * n, 0.0f); + + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint32_buffer(m * n, 0); + + randFill(Aint8, 0, 255); + int32_t Aint8_zero_point = 43; + float Aint8_scale = 0.11; + for (auto i = 0; i < Afp32.size(); ++i) { + Afp32[i] = Aint8_scale * (Aint8[i] - Aint8_zero_point); + } + + randFill(Bint8, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bint8.data()); + int32_t Bint8_zero_point = -30; + float Bint8_scale = 0.49; + for (auto i = 0; i < Bfp32.size(); ++i) { + Bfp32[i] = Bint8_scale * (Bint8[i] - Bint8_zero_point); + } + + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
+ int k_adjusted = k; + int n_adjusted = n; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8.data(), + Bint8_zero_point, + col_offsets.data()); + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + matmul_fp_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + Afp32.data(), + Bfp32.data(), + Cfp32_ref.data()); + + vector row_offset_buf; + row_offset_buf.resize( + PackAWithQuantRowOffset::rowOffsetBufferSize()); + + PackAWithQuantRowOffset packAN( + matrix_op_t::NoTranspose, + m, + k_adjusted, + Afp32.data(), + k, + nullptr, /*buffer for packed matrix*/ + Aint8_scale, + Aint8_zero_point, + 1, /*groups*/ + row_offset_buf.data()); + + PackBMatrix packedBN( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k : n, + nullptr, + 1, + Bint8_zero_point); + + DoNothing doNothingObj{}; + ReQuantizeForFloat outputProcObj( + doNothingObj, + Aint8_scale, + Bint8_scale, + Aint8_zero_point, + Bint8_zero_point, + packAN.getRowOffsetBuffer(), + col_offsets.data(), + nullptr); + + fbgemmPacked( + packAN, + packedBN, + Cfp32_fb.data(), + (int32_t*)Cfp32_fb.data(), + n, + outputProcObj, + 0, + 1); + + float maximum = *max_element(Cfp32_ref.begin(), Cfp32_ref.end()); + float minimum = *min_element(Cfp32_ref.begin(), Cfp32_ref.end()); + float atol = (maximum - minimum) / 255 / 1.9; + + compare_validate_buffers(Cfp32_ref.data(), Cfp32_fb.data(), m, n, n, atol); + } +} + +/** + * @brief Unit test for uint8 matrix A, int8 matrix B, and 32-bit + * accumulation. Output processing: requantization -> nothing. Symmetric: the + * zero point is 0. + */ +TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Afp32(m * k, 0.0f); + aligned_vector Aint8(m * k, 0); + + aligned_vector Bfp32(k * n, 0.0f); + aligned_vector Bint8(k * n, 0); + + aligned_vector Cfp32_ref(m * n, 0.0f); + aligned_vector Cint32_fb(m * n, 0); + + randFill(Afp32, 0, 255); + for (auto i = 0; i < Afp32.size(); i++) { + Aint8[i] = (uint8_t)Afp32[i]; + } + + // initialize B matrix + randFill(Bfp32, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bfp32.data()); + + for (auto i = 0; i < Bfp32.size(); ++i) { + Bint8[i] = (int8_t)Bfp32[i]; + } + + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
+ int m_adjusted = m; + int n_adjusted = n; + int k_adjusted = k; + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + matmul_fp_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + Afp32.data(), + Bfp32.data(), + Cfp32_ref.data()); + + DoNothing doNothingObj{}; + memCopy<> outputProcObj(doNothingObj); + // A zero point and row offset not required + PackAMatrix packAN( + matrix_op_t::NoTranspose, m, k_adjusted, Aint8.data(), k); + + // B zero point defaults to 0 + PackBMatrix packedBN( + btrans, + k_adjusted, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k : n); + + fbgemmPacked( + packAN, + packedBN, + Cint32_fb.data(), + Cint32_fb.data(), + n, + outputProcObj, + 0, + 1); + + // correctness check + for (int i = 0; i < m_adjusted; ++i) { + for (int j = 0; j < n_adjusted; ++j) { + float expected = Cfp32_ref[i * n + j]; + int32_t actual = Cint32_fb[i * n + j]; + EXPECT_EQ(expected, actual) + << "GEMM results differ at (" << i << ", " << j << "). ref " + << expected << " FBGemm " << actual; + } + } + } +} + +/** + * @brief Unit test for unt8 matrix A, int8 matrix B, and 32-bit + * accumulation. Output processing: requantization with bias -> nothing. + * Asymmetric: the zero point is not 0. + */ +TEST_P(fbgemmu8s8acc32test, TestAsymmetricQuantizedWithBias) { + vector> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + bool test_ld; + tie(atrans, btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k, 0); + aligned_vector Aint8_ref(m * k, 0); + + aligned_vector Bint8(k * n, 0); + aligned_vector Bint8_ref(k * n, 0); + + aligned_vector Cint32_fb(m * n, 0); + aligned_vector Cint32_ref(m * n, 0); + + aligned_vector Cint8_fb(m * n, 0); + aligned_vector Cint8_ref(m * n, 0); + + int n_adjusted = n; + int k_adjusted = k; + + if (test_ld) { + assert( + atrans == matrix_op_t::NoTranspose && "This case is not handled yet"); + k_adjusted = std::max(k / 2, 1); + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // A and B have scale 1, so exactly represented after quantization + randFill(Aint8, 0, 255); + randFill(Bint8, -128, 127); + avoidOverflow(m, n, k, Aint8.data(), Bint8.data()); + + for (auto i = 0; i < Bint8.size(); ++i) { + Bint8_ref[i] = Bint8[i]; + } + + for (auto i = 0; i < Aint8.size(); ++i) { + Aint8_ref[i] = Aint8[i]; + } + + int32_t Aint8_zero_point = 55; + int32_t Bint8_zero_point = -17; + + // initialize bias + aligned_vector bias_int32(n); + randFill(bias_int32, -128, 127); + + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(Bint8.data(), k, n); + } + + // computing column offset + vector col_offsets; + col_offsets.resize(n_adjusted); + col_offsets_with_zero_pt_s8acc32_ref( + k_adjusted, + n_adjusted, + n, + Bint8_ref.data(), + Bint8_zero_point, + col_offsets.data()); + + matmul_u8i8acc32_ref( + m, + n_adjusted, + k_adjusted, + k, + n, + n, + Aint8.data(), + Bint8_ref.data(), + Cint32_ref.data()); + + vector row_offsets; + row_offsets.resize(m); + + row_offsets_u8acc32_ref( + m, k_adjusted, k, Aint8_ref.data(), row_offsets.data()); + + float C_multiplier = 0.1234; + int32_t C_zero_pt = 5; + + requantize_u8acc32_ref( + m, + n_adjusted, + n, + 
        Cint32_ref.data(),
+        Cint8_ref.data(),
+        C_multiplier,
+        C_zero_pt,
+        Aint8_zero_point,
+        Bint8_zero_point,
+        row_offsets.data(),
+        col_offsets.data(),
+        bias_int32.data());
+
+    vector<int32_t> row_offset_buf;
+    row_offset_buf.resize(PackAWithRowOffset<uint8_t>::rowOffsetBufferSize());
+
+    PackAWithRowOffset<uint8_t> packAN(
+        matrix_op_t::NoTranspose,
+        m,
+        k_adjusted,
+        Aint8.data(),
+        k,
+        nullptr,
+        1,
+        Aint8_zero_point,
+        row_offset_buf.data());
+
+    PackBMatrix<int8_t> packedBN(
+        btrans,
+        k_adjusted,
+        n_adjusted,
+        Bint8.data(),
+        (btrans == matrix_op_t::Transpose) ? k : n,
+        nullptr,
+        1,
+        Bint8_zero_point);
+
+    DoNothing<> doNothingObj{};
+    ReQuantizeOutput<false> outputProcObj(
+        doNothingObj,
+        C_multiplier,
+        C_zero_pt,
+        Aint8_zero_point,
+        Bint8_zero_point,
+        packAN.getRowOffsetBuffer(),
+        col_offsets.data(),
+        bias_int32.data());
+
+    fbgemmPacked(
+        packAN,
+        packedBN,
+        Cint8_fb.data(),
+        Cint32_fb.data(),
+        n,
+        outputProcObj,
+        0,
+        1);
+
+    compare_validate_buffers(
+        Cint8_fb.data(), Cint8_ref.data(), m, n, n, static_cast<uint8_t>(0));
+  }
+}
diff --git a/test/QuantizationHelpers.cc b/test/QuantizationHelpers.cc
new file mode 100644
index 0000000..354519b
--- /dev/null
+++ b/test/QuantizationHelpers.cc
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "QuantizationHelpers.h"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+using namespace std;
+
+namespace fbgemm2 {
+/*
+ * @brief Make sure we won't have overflows from vpmaddubsw instruction.
+ */
+template <typename T>
+void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B) {
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int kk = 0; kk < k / 2 * 2; kk += 2) {
+        int a0 = Aint8[i * k + kk], a1 = Aint8[i * k + kk + 1];
+        int b0 = B[kk * n + j], b1 = B[(kk + 1) * n + j];
+        int sum_pair = a0 * b0 + a1 * b1;
+        if (sum_pair < numeric_limits<int16_t>::lowest()) {
+          int b1_adjusted =
+              ceil((numeric_limits<int16_t>::lowest() - a0 * b0) / a1);
+          b1_adjusted = std::min(std::max(b1_adjusted, -128), 127);
+
+          int new_sum_pair = a0 * b0 + a1 * b1_adjusted;
+          assert(
+              new_sum_pair >= numeric_limits<int16_t>::lowest() &&
+              new_sum_pair <= numeric_limits<int16_t>::max());
+          B[(kk + 1) * n + j] = b1_adjusted;
+        } else if (sum_pair > numeric_limits<int16_t>::max()) {
+          int b1_adjusted =
+              floor((numeric_limits<int16_t>::max() - a0 * b0) / a1);
+          b1_adjusted = std::min(std::max(b1_adjusted, -128), 127);
+
+          int new_sum_pair = a0 * b0 + a1 * b1_adjusted;
+          assert(
+              new_sum_pair >= numeric_limits<int16_t>::lowest() &&
+              new_sum_pair <= numeric_limits<int16_t>::max());
+          B[(kk + 1) * n + j] = b1_adjusted;
+        }
+      }
+    } // for each j
+  } // for each i
+}
+
+template void
+avoidOverflow(int m, int n, int k, const uint8_t* Aint8, int8_t* B);
+template void
+avoidOverflow(int m, int n, int k, const uint8_t* Aint8, float* B);
+} // namespace fbgemm2
diff --git a/test/QuantizationHelpers.h b/test/QuantizationHelpers.h
new file mode 100644
index 0000000..fdc6e02
--- /dev/null
+++ b/test/QuantizationHelpers.h
@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+
+namespace fbgemm2 {
+
+/*
+ * @brief Make sure we won't have overflows from vpmaddubsw instruction.
+ */
+template <typename T>
+void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B);
+
+} // namespace fbgemm2
diff --git a/test/TestUtils.cc b/test/TestUtils.cc
new file mode 100644
index 0000000..702425e
--- /dev/null
+++ b/test/TestUtils.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "TestUtils.h"
+#include
+#include "fbgemm/Fbgemm.h"
+#include "bench/AlignedVec.h"
+
+namespace fbgemm2 {
+
+template <typename T>
+int compare_validate_buffers(
+    const T* ref,
+    const T* test,
+    int m,
+    int n,
+    int ld,
+    T atol) {
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      if (std::is_integral<T>::value) {
+        EXPECT_EQ(ref[i * ld + j], test[i * ld + j])
+            << "GEMM results differ at (" << i << ", " << j
+            << ") reference: " << (int64_t)ref[i * ld + j]
+            << ", FBGEMM: " << (int64_t)test[i * ld + j];
+      } else {
+        EXPECT_LE(std::abs(ref[i * ld + j] - test[i * ld + j]), atol)
+            << "GEMM results differ at (" << i << ", " << j
+            << ") reference: " << ref[i * ld + j]
+            << ", FBGEMM: " << test[i * ld + j];
+      }
+    }
+  }
+  return 0;
+}
+
+template int compare_validate_buffers(
+    const float* ref,
+    const float* test,
+    int m,
+    int n,
+    int ld,
+    float atol);
+
+template int compare_validate_buffers(
+    const int32_t* ref,
+    const int32_t* test,
+    int m,
+    int n,
+    int ld,
+    int32_t atol);
+
+template int compare_validate_buffers(
+    const uint8_t* ref,
+    const uint8_t* test,
+    int m,
+    int n,
+    int ld,
+    uint8_t atol);
+
+template <typename T>
+bool check_all_zero_entries(const T* test, int m, int n) {
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      if (test[i * n + j] != 0)
+        return true;
+    }
+  }
+  return false;
+}
+
+template bool check_all_zero_entries(const float* test, int m, int n);
+template bool
+check_all_zero_entries(const int32_t* test, int m, int n);
+template bool
+check_all_zero_entries(const uint8_t* test, int m, int n);
+
+template <typename T>
+void transpose_matrix(T* ref, int n, int k) {
+  aligned_vector<T> local(n * k, 0);
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < k; ++j) {
+      local[j * n + i] = ref[i * k + j];
+    }
+  }
+  for (int i = 0; i < k; ++i) {
+    for (int j = 0; j < n; ++j) {
+      ref[i * n + j] = local[i * n + j];
+    }
+  }
+}
+
+template void transpose_matrix(float* ref, int n, int k);
+template void transpose_matrix(int32_t* ref, int n, int k);
+template void transpose_matrix(uint8_t* ref, int n, int k);
+template void transpose_matrix(int8_t* ref, int n, int k);
+} // namespace fbgemm2
diff --git a/test/TestUtils.h b/test/TestUtils.h
new file mode 100644
index 0000000..559f816
--- /dev/null
+++ b/test/TestUtils.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include
+#include
+
+namespace fbgemm2 {
+
+/*
+ * @brief Check and validate the buffers for reference and FBGEMM result.
+ */
+template <typename T>
+int compare_validate_buffers(
+    const T* ref,
+    const T* test,
+    int m,
+    int n,
+    int ld,
+    T atol);
+
+/*
+ * @brief Check if all entries are zero or not.
+ * If any entry is non-zero, return True;
+ * otherwise, return False.
+ */
+template <typename T>
+bool check_all_zero_entries(const T* test, int m, int n);
+
+/*
+ * @brief In-place transposition for nxk matrix ref.
+ * @params n number of rows in output (number of columns in input)
+ * @params k number of columns in output (number of rows in input)
+ */
+template <typename T>
+void transpose_matrix(T* ref, int n, int k);
+} // namespace fbgemm2
--
cgit v1.2.3
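
The avoidOverflow() helper added in test/QuantizationHelpers.cc above exists because the uint8 x int8 kernels rely on vpmaddubsw, which multiplies pairs of uint8/int8 values and adds each pair into a single int16 lane. Below is a minimal standalone sketch of that saturation hazard and of the kind of adjustment the helper applies; it is plain C++, independent of the FBGEMM headers in this patch, and the concrete operand values are illustrative only.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Worst case for one vpmaddubsw lane: 255 * 127 + 255 * 127 = 64770,
  // which is well above INT16_MAX (32767), so the pair sum would saturate.
  int a0 = 255, a1 = 255; // uint8 operands taken from A
  int b0 = 127, b1 = 127; // int8 operands taken from B
  int sum_pair = a0 * b0 + a1 * b1;
  std::cout << "pair sum " << sum_pair << " vs int16 max "
            << std::numeric_limits<int16_t>::max() << "\n";

  // Same idea as the adjustment in avoidOverflow: pull b1 down until the
  // pair sum fits into int16 again, then clamp it back to the int8 range.
  if (sum_pair > std::numeric_limits<int16_t>::max()) {
    b1 = (std::numeric_limits<int16_t>::max() - a0 * b0) / a1;
    b1 = std::min(std::max(b1, -128), 127);
  }
  assert(a0 * b0 + a1 * b1 <= std::numeric_limits<int16_t>::max());
  std::cout << "adjusted b1 = " << b1 << ", new pair sum = "
            << a0 * b0 + a1 * b1 << "\n";
  return 0;
}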