diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | bench/ConvUnifiedBenchmark.cc | 309 | ||||
-rw-r--r-- | bench/Depthwise3DBenchmark.cc | 2 | ||||
-rw-r--r-- | bench/DepthwiseBenchmark.cc | 2 | ||||
-rw-r--r-- | include/fbgemm/ConvUtils.h | 2 | ||||
-rw-r--r-- | include/fbgemm/Fbgemm.h | 107 | ||||
-rw-r--r-- | include/fbgemm/FbgemmI8DepthwiseAvx2.h | 175 | ||||
-rw-r--r-- | include/fbgemm/Utils.h | 5 | ||||
-rw-r--r-- | src/FbgemmConv.cc | 222 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.cc | 2 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.h | 4 | ||||
-rw-r--r-- | src/GroupwiseConv.h | 10 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 193 | ||||
-rw-r--r-- | src/PackWeightMatrixForGConv.cc | 4 | ||||
-rw-r--r-- | src/PackWeightsForConv.cc | 71 | ||||
-rw-r--r-- | src/RefImplementations.cc | 28 | ||||
-rw-r--r-- | src/RefImplementations.h | 31 | ||||
-rw-r--r-- | test/I8DepthwiseTest.cc | 2 | ||||
-rw-r--r-- | test/Im2ColFusedRequantizeTest.cc | 4 | ||||
-rw-r--r-- | test/UniConvPackingTest.cc | 148 |
20 files changed, 1205 insertions, 118 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 80de824..b575e17 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/ExecuteKernelU8S8.cc src/Fbgemm.cc src/FbgemmFP16.cc + src/FbgemmConv.cc src/FbgemmI8Spmdm.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -42,6 +43,7 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/PackAWithQuantRowOffset.cc src/PackAWithRowOffset.cc src/PackWeightMatrixForGConv.cc + src/PackWeightsForConv.cc src/QuantUtils.cc src/RefImplementations.cc src/Utils.cc) diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc new file mode 100644 index 0000000..59079c7 --- /dev/null +++ b/bench/ConvUnifiedBenchmark.cc @@ -0,0 +1,309 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <chrono> +#include <cmath> +#include <iomanip> +#include <iostream> +#include <random> +#include <vector> + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +// 2D conv shapes +vector<conv_param_t<2>> shapes_2d = { + // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, + // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right + // 2D convolutions + // regular + conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // groupwise + conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + + // DW + conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}), +}; + +// 3D conv shapes +vector<conv_param_t<3>> shapes_3d = { + // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w}, + // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right} + // Regular + conv_param_t<3>(1, 64, 64, {32, 56, 56}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), + // Depthwise + conv_param_t<3>(1, 64, 64, {32, 56, 56}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}) +}; + +template <int SPATIAL_DIM, typename Acc_t> +void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { + bool flush = true; + std::vector<char> llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + + string header = "MB, IC, OC, "; + if (SPATIAL_DIM == 3) { + header += "IT, "; + } + header += "IH, IW, G, "; + if (SPATIAL_DIM == 3) { + header += "KT, "; + } + header += "KH, KW, "; + if (SPATIAL_DIM == 3) { + header += "stride_t, "; + } + header += "stride_h, stride_w, "; + if (SPATIAL_DIM == 3) { + header += "pad_t, "; + } + header += "pad_h, pad_w, "; + + header += "Type, M, N, K, "; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." + << endl; + cout << header << "Im2Col (ms), " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "fbgemmPacked (ms), " + << "Total (ms), " + << "GOPS" << endl; +#else + cout << setw(6) << header << setw(5) << "GOPS" << endl; +#endif + + chrono::time_point<chrono::high_resolution_clock> begin, end; + + for (auto conv_p : shapes) { + if (conv_p.IC % conv_p.G != 0 || conv_p.OC % conv_p.G != 0) { + // invalid shapes + continue; + } + int im_in_dim = accumulate( + conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>()); + aligned_vector<uint8_t> Aint8(conv_p.MB * im_in_dim * conv_p.IC); + + int kernel_dim = + accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>()); + aligned_vector<int8_t> Bint8( + kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); + + int im_out_dim = accumulate( + conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>()); + aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC); + aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0); + aligned_vector<int32_t> Cint32_fb(Cint32_ref.size()); + aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_fb2(Cint32_ref.size(), 0); + aligned_vector<int32_t> Cint32_fb2(Cint32_ref.size()); + + // A matrix (input activations) + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + // B matrix (weights) + randFill<int8_t>(Bint8, -4, 4); + aligned_vector<int32_t> Bint8_zero_point(1); + randFill(Bint8_zero_point, -3, -1); + + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_point = 5; + + aligned_vector<float> Bfp32(Bint8.begin(), Bint8.end()); + + // reference implementation + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + // matrix dimensions after im2col + int MDim = conv_p.MB * im_out_dim; + int NDim = conv_p.OC / conv_p.G; + int KDim = kernel_dim * conv_p.IC; + int KDimPerGroup = KDim / conv_p.G; + + int OC_per_G = conv_p.OC / conv_p.G; + + // computing row offset + vector<int32_t> row_offsets(MDim); + vector<uint8_t> Aint8_im2col(MDim * KDim); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + + // computing column offset + vector<int32_t> col_offsets(conv_p.OC); + for (int g = 0; g < conv_p.G; ++g) { + col_offsets_with_zero_pt_s8acc32_ref( + KDimPerGroup, + OC_per_G, + OC_per_G, + Bint8.data() + g * KDimPerGroup * OC_per_G, + Bint8_zero_point.data(), + col_offsets.data() + g * OC_per_G, + conv_p.OC); + } + + for (int g = 0; g < conv_p.G; ++g) { + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + Aint8_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + requantize_u8acc32_ref( + MDim, + NDim, + conv_p.G * NDim, + Cint32_ref.data() + g * NDim, + Cint8_ref.data() + g * NDim, + C_multiplier.data() + g * NDim / conv_p.OC, + C_zero_point, + Aint8_zero_point, + Bint8_zero_point.data() + g * NDim / conv_p.OC, + row_offsets.data(), + col_offsets.data() + g * NDim, + nullptr, + conv_p.OC); + } + + double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim; + double ttot = 0.0; + string runType; + + PackWeightsForConv<SPATIAL_DIM> packedB(conv_p, Bint8.data()); + + // no-op output process objects + DoNothing<> doNothingObj{}; + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj( + doNothingObj, + C_multiplier.data(), + C_zero_point, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, // row offsets + col_offsets.data(), + nullptr, // bias + conv_p.OC, + conv_p.G); + + runType = "UniConv"; + ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double im2col_time = 0.0; + double total_im2col_time = 0.0; + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + fbgemmConv( + conv_p, + Aint8.data(), + packedB, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + + cout << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC << ", "; + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.IN_DIM[i] << ", "; + } + cout << conv_p.G << ", "; + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.K[i] << ", "; + } + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.stride[i] << ", "; + } + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.pad[i] << ", "; + } + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) << 0 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; + + compare_buffers( + Cint8_ref.data(), + Cint8_fb.data(), + MDim, + NDim * conv_p.G, + NDim * conv_p.G, + 5); + } // shapes +} + +int main() { +#ifdef _OPENMP + // Use 1 thread unless OMP_NUM_THREADS is explicit set. + const char* val = getenv("OMP_NUM_THREADS"); + if (val == nullptr || !*val) { + omp_set_num_threads(1); + } +#endif + // performance_test<int16_t>(); + performance_test<2, int32_t>(shapes_2d); + performance_test<3, int32_t>(shapes_3d); + return 0; +} diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc index c5f8ed9..0efdcac 100644 --- a/bench/Depthwise3DBenchmark.cc +++ b/bench/Depthwise3DBenchmark.cc @@ -20,7 +20,7 @@ #include "AlignedVec.h" #include "BenchUtils.h" #include "fbgemm/Utils.h" -#include "src/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "src/RefImplementations.h" using namespace std; diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc index 780d83c..96921a1 100644 --- a/bench/DepthwiseBenchmark.cc +++ b/bench/DepthwiseBenchmark.cc @@ -18,7 +18,7 @@ #include "AlignedVec.h" #include "BenchUtils.h" #include "fbgemm/Utils.h" -#include "src/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "src/RefImplementations.h" using namespace std; diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h index 1c8251e..11f3dcc 100644 --- a/include/fbgemm/ConvUtils.h +++ b/include/fbgemm/ConvUtils.h @@ -101,7 +101,7 @@ struct conv_param_t { std::to_string(stride[d]) + ", "; } for (int d = 0; d < SPATIAL_DIM * 2; ++d) { - out += "pad_" + dim_string[3 - (SPATIAL_DIM % 3) + d] + ":" + + out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" + std::to_string(pad[d]); if (d < SPATIAL_DIM * 2 - 1) { out += ", "; diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 720f681..721b12f 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -18,6 +18,7 @@ #include "FbgemmBuild.h" #include "FbgemmI8Spmdm.h" #include "QuantUtilsAvx2.h" +#include "FbgemmI8DepthwiseAvx2.h" #include "Types.h" #include "Utils.h" @@ -524,6 +525,62 @@ class FBGEMM_API PackWeightMatrixForGConv { }; /** + * @brief A container class to keep packed weight tensor for convolution. + * The source tensor should already be quantized. + * + * @tparam SPATIAL_DIM is equal to 2 for 2D convolutions and 3 for 3D + * convolutions. Default value is 2. + * @tparam T is the datatype for source tensor. Default value is int8. + * @tparam accT is the datatype to accumulate into. Default value is int32. + */ +template < + int SPATIAL_DIM = 2, + typename T = std::int8_t, + typename accT = std::int32_t> +class FBGEMM_API PackWeightsForConv { + public: + using This = PackWeightsForConv<SPATIAL_DIM, T, accT>; + using inpType = T; + using accType = accT; + + PackWeightsForConv() = delete; // no default constructor + + PackWeightsForConv( + const conv_param_t<SPATIAL_DIM>& conv_param, + const inpType* sdata, + const BlockingFactors* blocking_params = nullptr); + + std::shared_ptr<PackBMatrix<T, accT>> getPackedWForIm2col() { + return W_im2col_packed_; + } + + std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() { + return W_dw_2D_packed_; + } + + std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() { + return W_dw_3D_packed_; + } + + std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> + getPackedWForGroupwise() { + return W_gconv_packed_; + } + + private: + // Packed weights if we use im2col based convolution implementation + std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_; + // Packed weights if we use 2D depthwise convolution implementation + std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_; + // Packed weights if we use 3D depthwise convolution implementation + std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_; + // Packed weights if we use groupwise (small channels per group) convolution + // implementation + std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> + W_gconv_packed_; +}; + +/** * @brief Matrix packed for the first input matrix in GEMM (usually activation), * and row offsets used for requantization is computed during packing. * Im2col is fused with packing here. The source matrix is already @@ -1002,6 +1059,7 @@ template < typename nextOPType = DoNothing<outT, outT>> class FBGEMM_API ReQuantizeOutput { public: + static constexpr int RELU_FUSED = FUSE_RELU; using outType = outT; using inpType = inT; /** @@ -1056,12 +1114,18 @@ class FBGEMM_API ReQuantizeOutput { const float* getCMultiplier() const { return C_multiplier_; } + std::int32_t getAZeroPoint() const { + return Aq_zero_point_; + } std::int32_t getCZeroPoint() const { return C_zero_point_; } const std::int32_t* getBZeroPoint() const { return Bq_zero_point_; } + const std::int32_t* getRowOffsets() const { + return q_row_offsets_; + } const std::int32_t* getColOffsets() const { return q_col_offsets_; } @@ -1072,6 +1136,10 @@ class FBGEMM_API ReQuantizeOutput { return ncols_; } + void setRowOffsets(const std::int32_t* row_offsets) { + q_row_offsets_ = row_offsets; + } + private: nextOPType& nextop_; const float* C_multiplier_; @@ -1273,6 +1341,12 @@ void convDepthwiseSeparable( const processOutputType& output); /** + * @brief Is this depthwise convolution optimized? + */ +template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t> +bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p); + +/** * @brief Is this groupwise convolution supported? */ template <int SPATIAL_DIM> @@ -1351,4 +1425,37 @@ static void fbgemmGetRange( } } +/** + * @brief Performs convolution using fastest path available. + * + * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions. + */ +template < + typename processOutputType, + int SPATIAL_DIM = 2, + typename ACC_T = std::int32_t> +FBGEMM_API int fbgemmConv( + const conv_param_t<SPATIAL_DIM>& conv_p, + const std::uint8_t* activations, + PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, + typename processOutputType::outType* out, + std::int32_t* outBuffer, + processOutputType& outProcess, + int thread_id, + int num_threads, + const BlockingFactors* blocking_params = nullptr); + +/** + * @brief Returns which fast path to take + * + * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions. + * + * @return optimized_conv_t::depthwise, optimized_conv_t::groupwise or + * optimized_conv_t::im2col + * + */ +template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t> +FBGEMM_API optimized_conv_t +ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p); + } // namespace fbgemm diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h new file mode 100644 index 0000000..432687c --- /dev/null +++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef I8DEPTHWISE_H +#define I8DEPTHWISE_H +#include <cstdint> +#include "fbgemm/FbgemmBuild.h" + +namespace fbgemm { + +// KERNEL_PROD is the product of all kernels. +// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3. +template <int KERNEL_PROD> +class FBGEMM_API PackedDepthWiseConvMatrix { + public: + // smat in RSG layout + PackedDepthWiseConvMatrix(int K, const std::int8_t* smat); + virtual ~PackedDepthWiseConvMatrix(); + + const std::int8_t* PackedMat() const { + return pmat_; + } + + private: + int K_; + std::int8_t* pmat_; +}; // Packed3x3ConvMatrix + +using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>; +using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * @params A The input image in NHWK layout + * @params Bp The pre-packed filter + */ +FBGEMM_API void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const Packed3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization. + * + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const Packed3x3ConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, + int thread_id = 0, + int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization and uses per-channel quantization. + * + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const Packed3x3ConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, + int thread_id = 0, + int num_threads = 1); + +FBGEMM_API void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const Packed3x3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const Packed3x3x3ConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, + int thread_id = 0, + int num_threads = 1); + +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, + int thread_id = 0, + int num_threads = 1); + +} // namespace fbgemm + +#endif diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 58232e7..ef1d4ab 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -32,6 +32,11 @@ enum class matrix_op_t { NoTranspose, Transpose }; enum class inst_set_t { anyarch, avx2, avx512 }; /** + * @brief Typed enum for optimized paths for convolutions + */ +enum class optimized_conv_t { depthwise, groupwise, im2col }; + +/** * @brief Typed enum for implementation type. * * ref is reference and opt is optimized. diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc new file mode 100644 index 0000000..5db63f6 --- /dev/null +++ b/src/FbgemmConv.cc @@ -0,0 +1,222 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <algorithm> +#include <iostream> +#include <vector> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm { + +template <int SPATIAL_DIM, typename ACC_T> +bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { + // Note: Depthwise convolutions (both 2D and 3D) are optimized for the most + // common case. + return std::is_same<ACC_T, std::int32_t>::value && conv_p.G == conv_p.IC && + conv_p.G == conv_p.OC && conv_p.G % 8 == 0 && + std::all_of( + conv_p.stride.begin(), + conv_p.stride.end(), + [](int i) { return i == 1 || i == 2; }) && + std::all_of( + conv_p.K.begin(), conv_p.K.end(), [](int i) { return i == 3; }) && + std::all_of( + conv_p.dilation.begin(), + conv_p.dilation.end(), + [](int i) { return i == 1; }) && + std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [](int i) { + return i == 1; + }); +} + +template <int SPATIAL_DIM, typename ACC_T> +optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { + if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { + return optimized_conv_t::depthwise; + } else if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) { + return optimized_conv_t::groupwise; + } else { + return optimized_conv_t::im2col; + } +} + +template <typename processOutputType, int SPATIAL_DIM, typename ACC_T> +int fbgemmConv( + const conv_param_t<SPATIAL_DIM>& conv_p, + const std::uint8_t* activations, + PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, + typename processOutputType::outType* out, + std::int32_t* outBuffer, + processOutputType& outProcess, + int thread_id, + int num_threads, + const BlockingFactors* blocking_params) { + static_assert( + SPATIAL_DIM == 2 || SPATIAL_DIM == 3, + "Only 2D and 3D convolutions are supported"); + switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { + case optimized_conv_t::depthwise: { + // 2D and 3D depthwise fast path + // std::cout << "Depthwise fast path" << std::endl; + const std::int32_t* B_zero_point = outProcess.getBZeroPoint(); + const float* C_multiplier = outProcess.getCMultiplier(); + if (SPATIAL_DIM == 3) { + static_assert( + std::is_same<typename processOutputType::outType, std::uint8_t>:: + value, + "For depthwise, only requantized output is supported"); + depthwise_3x3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // T + conv_p.IN_DIM[1], // H + conv_p.IN_DIM[2], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_t + conv_p.stride[1], // stride_h + conv_p.stride[2], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWFor3DDW()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + thread_id, + num_threads); + } else { + depthwise_3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // H + conv_p.IN_DIM[1], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_h + conv_p.stride[1], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWFor2DDW()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + thread_id, + num_threads); + } + break; + } + case optimized_conv_t::groupwise: { + // optimized groupwise convolution + // std::cout << "Groupwise fast path" << std::endl; + assert( + SPATIAL_DIM == 2 && "Only 2D groupwise convolutions are supported"); + std::vector<int32_t> row_offset_buf( + rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p)); + outProcess.setRowOffsets(row_offset_buf.data()); + fbgemmGroupwiseConv( + conv_p, + activations, + outProcess.getAZeroPoint(), + row_offset_buf.data(), + *(packed_weights.getPackedWForGroupwise()), + out, + outBuffer, + outProcess, + thread_id, + num_threads); + break; + } + case optimized_conv_t::im2col: { + // All other convolutions go through im2col-based implementation + // std::cout << "Im2col path" << std::endl; + std::vector<int32_t> row_offset_buf( + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize()); + + const std::int32_t* b_zero_point = outProcess.getBZeroPoint(); + bool b_symmetric = b_zero_point[0] == 0; + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA( + conv_p, + activations, + nullptr, /* buffer for packed matrix */ + outProcess.getAZeroPoint(), + row_offset_buf.data(), + b_symmetric, + blocking_params); + + outProcess.setRowOffsets(row_offset_buf.data()); + fbgemmPacked( + packA, + *(packed_weights.getPackedWForIm2col()), + out, + outBuffer, + conv_p.OC, + outProcess, + thread_id, + num_threads, + blocking_params); + break; + } + } // switch + + return 0; +} + +#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \ + template int fbgemmConv( \ + const conv_param_t<SPATIAL_DIM>& conv_p, \ + const std::uint8_t* activations, \ + PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \ + std::uint8_t* out, \ + std::int32_t* outBuffer, \ + ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + int thread_id, \ + int num_threads, \ + const BlockingFactors* blocking_params); + +#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3); + +#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \ + INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \ + INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, false); + +#define INSTANTIATE_Q_GRANS(ACC_T) \ + INSTANTIATE_RELU(ACC_T, QuantizationGranularity::TENSOR); \ + INSTANTIATE_RELU(ACC_T, QuantizationGranularity::GROUP); \ + INSTANTIATE_RELU(ACC_T, QuantizationGranularity::OUT_CHANNEL); + +INSTANTIATE_Q_GRANS(std::int32_t); + +#undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_RELU +#undef INSTANTIATE_SPATIAL_DIM +#undef INSTANTIATE_BASE + +template bool takeDepthWiseFastPath<2, std::int32_t>( + const conv_param_t<2>& conv_p); +template bool takeDepthWiseFastPath<3, std::int32_t>( + const conv_param_t<3>& conv_p); +template bool takeDepthWiseFastPath<2, std::int16_t>( + const conv_param_t<2>& conv_p); +template bool takeDepthWiseFastPath<3, std::int16_t>( + const conv_param_t<3>& conv_p); + +template optimized_conv_t ConvFastPath<2, std::int32_t>( + const conv_param_t<2>& conv_p); +template optimized_conv_t ConvFastPath<3, std::int32_t>( + const conv_param_t<3>& conv_p); +template optimized_conv_t ConvFastPath<2, std::int16_t>( + const conv_param_t<2>& conv_p); +template optimized_conv_t ConvFastPath<3, std::int16_t>( + const conv_param_t<3>& conv_p); + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 2620e43..ee39faf 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -4,7 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include <algorithm> // for min and max #include <cassert> diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h index 069ff77..e2730df 100644 --- a/src/FbgemmI8DepthwiseAvx2.h +++ b/src/FbgemmI8DepthwiseAvx2.h @@ -4,7 +4,8 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#pragma once +#ifndef I8DEPTHWISE_H +#define I8DEPTHWISE_H #include <cstdint> #include "fbgemm/FbgemmBuild.h" @@ -170,3 +171,4 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( int num_threads = 1); } // namespace fbgemm +#endif diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 4e539f2..1e6324e 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -36,10 +36,12 @@ using jit_rowoffset_kernel_fp = void (*)( int32_t width, int32_t* row_offset); -template <typename accT = int32_t> +template <int SPATIAL_DIM = 2, typename accT = int32_t> class GenConvKernel { public: - GenConvKernel(const conv_param_t<>& conv_param, std::int32_t a_zero_point) + GenConvKernel( + const conv_param_t<SPATIAL_DIM>& conv_param, + std::int32_t a_zero_point) : WRegs_avx2_{x86::ymm0, x86::ymm1, x86::ymm2, @@ -119,11 +121,11 @@ class GenConvKernel { ~GenConvKernel() {} template <inst_set_t instSet> - jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param); + jit_conv_kernel_fp getOrCreate(const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> jit_rowoffset_kernel_fp getOrCreateRowOffset( - const conv_param_t<>& conv_param); + const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> void createVector16BitOne(asmjit::X86Emitter* a); diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index 0032e72..e789695 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -21,24 +21,25 @@ namespace fbgemm { using namespace std; -template <typename accT> -thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_; +template <int SPATIAL_DIM, typename accT> +thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_; -template <typename accT> -thread_local asmjit::CodeHolder GenConvKernel<accT>::code_; +template <int SPATIAL_DIM, typename accT> +thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_; -template <typename accT> +template <int SPATIAL_DIM, typename accT> thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> - GenConvKernel<accT>::codeCache_; + GenConvKernel<SPATIAL_DIM, accT>::codeCache_; -template <typename accT> +template <int SPATIAL_DIM, typename accT> thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> - GenConvKernel<accT>::codeCacheRowOffset_; + GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; namespace x86 = asmjit::x86; +template <int SPATIAL_DIM> void calculateRowOffsets( - const conv_param_t<>& conv_param, + const conv_param_t<SPATIAL_DIM>& conv_param, const uint8_t* activations, int32_t* rowOffsetBuf, int32_t a_zero_point, @@ -72,8 +73,9 @@ void calculateRowOffsets( } } +template <int SPATIAL_DIM = 2> tuple<bool, int, int, int> getKernelSig( - const conv_param_t<>& conv_param, + const conv_param_t<SPATIAL_DIM>& conv_param, bool isAZeroPointZero) { int C_per_G = conv_param.IC / conv_param.G; int K_per_G = conv_param.OC / conv_param.G; @@ -82,18 +84,18 @@ tuple<bool, int, int, int> getKernelSig( return kernelSig; } -template <typename accT = int32_t> +template <int SPATIAL_DIM = 2, typename accT = int32_t> jit_conv_kernel_fp getOrCreateConvKernel( - const conv_param_t<>& conv_param, + const conv_param_t<SPATIAL_DIM>& conv_param, int a_zero_point) { // Note: Wrong code is generated if it's not one of the supported convolution - assert(fbgemmOptimizedGConv<2>(conv_param)); + assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<accT>::codeCache_.find(kernelSig) != - GenConvKernel<accT>::codeCache_.end()) { - return GenConvKernel<accT>::codeCache_[kernelSig]; + if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) != + GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) { + return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig]; } else { - auto genObj = GenConvKernel<accT>(conv_param, a_zero_point); + auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); // TODO: Instruction set based dispatch return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); } @@ -101,7 +103,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> -void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( asmjit::X86Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains @@ -112,7 +114,7 @@ void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( asmjit::X86Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] @@ -122,7 +124,7 @@ void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>( } template <> template <> -void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm destReg) { // make destReg all zeros @@ -140,7 +142,7 @@ void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( asmjit::X86Emitter* a) { asmjit::X86Gp permute_const_reg = a->gpzRef(12); asmjit::X86Xmm const_reg_xmm = x86::xmm10; @@ -157,7 +159,7 @@ void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( asmjit::X86Emitter* a) { if (C_per_G_ == 4) { // store with permutation @@ -168,7 +170,7 @@ void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a, int offset) { // store @@ -195,8 +197,9 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // load weights for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_; ++s) { @@ -221,7 +224,7 @@ void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg) { @@ -232,7 +235,7 @@ void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); @@ -242,7 +245,7 @@ void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg) { @@ -263,7 +266,7 @@ void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg, @@ -315,7 +318,7 @@ void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( asmjit::X86Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { @@ -381,7 +384,7 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( asmjit::X86Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); @@ -395,8 +398,9 @@ void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // top-left corner code if (c_offset == 0) { // zero out the results register @@ -554,8 +558,9 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -620,8 +625,9 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -707,8 +713,9 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // bottom-left corner // we updating the last row a->mov(scratchReg1_, H_R_); @@ -898,8 +905,9 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a, int c_offset) { +void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); asmjit::Label LoopW = a->newLabel(); @@ -1000,8 +1008,8 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( template <> template <> -jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>( - const conv_param_t<>& conv_param) { +jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( + const conv_param_t<2>& conv_param) { code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -1108,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a) { // top-left corner code // zero out the results register @@ -1204,7 +1212,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -1247,7 +1255,7 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1317,7 +1325,7 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a) { // bottom-left corner // zero out @@ -1420,7 +1428,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>( +void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( asmjit::X86Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1480,8 +1488,8 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>( template <> template <> jit_rowoffset_kernel_fp -GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( - const conv_param_t<>& conv_param) { +GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( + const conv_param_t<2>& conv_param) { code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -1598,7 +1606,6 @@ void fbgemmGroupwiseConvBase_( const processOutputType& outProcess, int thread_id, int num_threads) { - int MB = conv_param.MB; int H = conv_param.OUT_DIM[0]; int W = conv_param.OUT_DIM[1]; @@ -1608,14 +1615,14 @@ void fbgemmGroupwiseConvBase_( int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1]; - static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + assert(SPATIAL_DIM == 2 && "3D groupwise conv not supported yet"); int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr; if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) { assert(G % 8 == 0); // generate convolution kernel jit_conv_kernel_fp fpConv = - getOrCreateConvKernel<>(conv_param, a_zero_point); + getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point); // generate row offset kernel jit_rowoffset_kernel_fp fpRowoffset = getOrCreateRowOffsetKernel(conv_param, a_zero_point); @@ -1670,8 +1677,7 @@ void fbgemmGroupwiseConvBase_( for (int j = 0; j < gDelta; ++j) { // calculateRowOffsets( // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j); - int32_t* rowOffsetForCurG = - rowOffsetTrDest + int32_t* rowOffsetForCurG = rowOffsetTrDest ? rowOffsetTrDest + ((g - gOuter) + j) * ih_iw : nullptr; // compare_buffers(rowOffsetBuf, rowOffsetForCurG, @@ -1735,7 +1741,7 @@ void fbgemmGroupwiseConvBase_( } } -} +} // namespace template < typename packed_W, @@ -1820,13 +1826,13 @@ void fbgemmGroupwiseConv( int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1]; - static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + assert(SPATIAL_DIM == 2 && "3D conv not supported yet"); int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr; assert(G % 8 == 0); // generate convolution kernel jit_conv_kernel_fp fpConv = - getOrCreateConvKernel<>(conv_param, a_zero_point); + getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point); // generate row offset kernel jit_rowoffset_kernel_fp fpRowoffset = getOrCreateRowOffsetKernel(conv_param, a_zero_point); @@ -2132,17 +2138,37 @@ void fbgemmGroupwiseConv( } // i loop } +// 3D not implemented yet +template <> +template <> +jit_conv_kernel_fp GenConvKernel<3, int32_t>::getOrCreate<inst_set_t::avx2>( + const conv_param_t<3>& /* unused */) { + assert(0 && "not implemented yet"); + return nullptr; +} + +template <> +template <> +jit_rowoffset_kernel_fp +GenConvKernel<3, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( + const conv_param_t<3>& conv_param) { + assert(0 && "not implemented yet"); + return nullptr; +} + +template <int SPATIAL_DIM = 2> jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( - const conv_param_t<>& conv_param, + const conv_param_t<SPATIAL_DIM>& conv_param, int a_zero_point) { // Note: Wrong code is generated if it's not one of the supported convolution - assert(fbgemmOptimizedGConv<2>(conv_param)); + assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) != - GenConvKernel<int32_t>::codeCacheRowOffset_.end()) { - return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig]; + if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find( + kernelSig) != + GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) { + return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig]; } else { - auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point); + auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); // TODO: Instruction set based dispatch return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); } @@ -2152,6 +2178,7 @@ template <int SPATIAL_DIM> int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { // row offset buffer should be a able to hold row offsets for however // number of groups we process at a time. + assert(SPATIAL_DIM == 2 && "Only 2D is supported currently"); if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; @@ -2186,29 +2213,35 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { } template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param); - -#define INSTANTIATE_BASE(RELU, Q_GRAN) \ - template void fbgemmGroupwiseConv( \ - const conv_param_t<2>& conv_param, \ - const uint8_t* activations, \ - int32_t a_zero_point, \ - std::int32_t* rowOffsetBuf, \ - PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \ - uint8_t* out, \ - int32_t* outBuffer, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ - int thread_id, \ +template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param); + +#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \ + template void fbgemmGroupwiseConv( \ + const conv_param_t<SPATIAL_DIM>& conv_param, \ + const uint8_t* activations, \ + int32_t a_zero_point, \ + std::int32_t* rowOffsetBuf, \ + PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \ + uint8_t* out, \ + int32_t* outBuffer, \ + const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + int thread_id, \ int num_threads); -#define INSTANTIATE_Q_GRANS(RELU) \ - INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL); +#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \ + INSTANTIATE_BASE(RELU, Q_GRAN, 2); \ + INSTANTIATE_BASE(RELU, Q_GRAN, 3); + +#define INSTANTIATE_Q_GRANS(RELU) \ + INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \ + INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::GROUP); \ + INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::OUT_CHANNEL); INSTANTIATE_Q_GRANS(false); INSTANTIATE_Q_GRANS(true); #undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_SPATIAL_DIM #undef INSTANTIATE_BASE template void fbgemmGroupwiseConv( diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index 5870fa5..0fb0e2c 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -19,7 +19,7 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( const T* sdata, T* pdata) : trans_(trans), conv_param_(conv_param), sdata_(sdata) { - static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + assert(SPATIAL_DIM == 2 && "3D conv not supported yet"); if (!pdata) { bufAllocatedHere_ = true; @@ -111,4 +111,6 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { template class PackWeightMatrixForGConv<int8_t, int32_t, 2>; template class PackWeightMatrixForGConv<int8_t, int16_t, 2>; +template class PackWeightMatrixForGConv<int8_t, int32_t, 3>; +template class PackWeightMatrixForGConv<int8_t, int16_t, 3>; } // namespace fbgemm diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc new file mode 100644 index 0000000..c811144 --- /dev/null +++ b/src/PackWeightsForConv.cc @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <memory> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm { + +template <int SPATIAL_DIM, typename T, typename accT> +PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( + const conv_param_t<SPATIAL_DIM>& conv_p, + const T* sdata, + const BlockingFactors* blocking_params) { + static_assert( + SPATIAL_DIM == 2 || SPATIAL_DIM == 3, + "Only 2D and 3D convolutions are supported"); + // Note: The following logic should *exactly* match with what we have in + // FbgemmConv.cc + switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) { + case optimized_conv_t::depthwise: { + if (SPATIAL_DIM == 3) { + W_im2col_packed_ = nullptr; + W_dw_2D_packed_ = nullptr; + W_dw_3D_packed_ = + std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata); + W_gconv_packed_ = nullptr; + } else { + W_im2col_packed_ = nullptr; + W_dw_2D_packed_ = + std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata); + W_dw_3D_packed_ = nullptr; + W_gconv_packed_ = nullptr; + } + break; + } + case optimized_conv_t::groupwise: { + W_im2col_packed_ = nullptr; + W_dw_2D_packed_ = nullptr; + W_dw_3D_packed_ = nullptr; + W_gconv_packed_ = + std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>( + matrix_op_t::NoTranspose, conv_p, sdata, nullptr); + break; + } + case optimized_conv_t::im2col: { + int NDim = conv_p.OC / conv_p.G; + int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; + W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>( + matrix_op_t::NoTranspose, + KDim, + NDim, + sdata, + NDim, + nullptr, + conv_p.G, + blocking_params); + W_dw_2D_packed_ = nullptr; + W_dw_3D_packed_ = nullptr; + W_gconv_packed_ = nullptr; + break; + } + } // switch +} + +template class PackWeightsForConv<2, int8_t, int32_t>; +template class PackWeightsForConv<3, int8_t, int32_t>; + +} // namespace fbgemm diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 72ef93f..b4b0c2b 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -285,8 +285,9 @@ int32_t clip_16bit(int32_t x) { * A: NHWC: NH_0W_0 x C_0 * Ao: NHWC: NH_1W_1 x G RS C_0/G */ +template <> void im2col_ref( - const conv_param_t<>& conv_p, + const conv_param_t<2>& conv_p, const uint8_t* A, int32_t A_zero_point, uint8_t* Ao) { @@ -346,7 +347,8 @@ void im2col_ref( * A: NHWC: NT_0H_0W_0 x C_0 * Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G */ -void im2col3d_ref( +template <> +void im2col_ref( const conv_param_t<3>& conv_p, const uint8_t* A, int32_t A_zero_point, @@ -422,8 +424,10 @@ void im2col3d_ref( } // for each n } +// 2D Conv +template <> void conv_ref( - const conv_param_t<>& conv_p, + const conv_param_t<2>& conv_p, const uint8_t* A, int32_t A_zero_point, const int8_t* B, @@ -471,7 +475,9 @@ void conv_ref( } // for each n } -void conv3d_ref( +// 3D Conv +template <> +void conv_ref( const conv_param_t<3>& conv_p, const uint8_t* A, int32_t A_zero_point, @@ -531,10 +537,12 @@ void conv3d_ref( } // for each n } +template <int SPATIAL_DIM> void transposeConvWeights( - const conv_param_t<>& conv_p, + const conv_param_t<SPATIAL_DIM>& conv_p, const std::int8_t* src, std::int8_t* dest) { + assert(SPATIAL_DIM == 2 && "Only 2D supported currently"); int R = conv_p.K[0]; int S = conv_p.K[1]; int G = conv_p.G; @@ -956,4 +964,14 @@ void depthwise_3x3x3_per_channel_quantization_pad_1_ref( } }; +template void transposeConvWeights( + const conv_param_t<2>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + +template void transposeConvWeights( + const conv_param_t<3>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + } // namespace fbgemm diff --git a/src/RefImplementations.h b/src/RefImplementations.h index 117c8a1..082bdf1 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -180,15 +180,9 @@ int32_t clip_16bit(int32_t x); * The filters B are assumed to be in RSCK format. * The output C is assumed to be in NHoWoC format. */ +template <int SPATIAL_DIM = 2> FBGEMM_API void conv_ref( - const conv_param_t<>& conv_p, - const std::uint8_t* A, - std::int32_t A_zero_point, - const std::int8_t* B, - std::int32_t* C); - -FBGEMM_API void conv3d_ref( - const conv_param_t<3>& conv_p, + const conv_param_t<SPATIAL_DIM>& conv_p, const std::uint8_t* A, std::int32_t A_zero_point, const std::int8_t* B, @@ -197,29 +191,26 @@ FBGEMM_API void conv3d_ref( /* * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. */ +template <int SPATIAL_DIM = 2> FBGEMM_API void transposeConvWeights( - const conv_param_t<>& conv_p, + const conv_param_t<SPATIAL_DIM>& conv_p, const std::int8_t* src, std::int8_t* dest); /* * @brief Reference implementation of im2col operation. + * + * For 2D: * The input A is assumed to be in NHiWiC format. * The output A is assumed to be in NHoWoRSC format. - */ -FBGEMM_API void im2col_ref( - const conv_param_t<>& conv_p, - const std::uint8_t* A, - std::int32_t A_zero_point, - std::uint8_t* Ao); - -/* - * @brief Reference implementation of im2col 3D operation. + * + * For 3D: * The input A is assumed to be in NTiHiWiC format. * The output A is assumed to be in NToHoWoK0K1K2C format. */ -FBGEMM_API void im2col3d_ref( - const conv_param_t<3>& conv_p, +template <int SPATIAL_DIM = 2> +FBGEMM_API void im2col_ref( + const conv_param_t<SPATIAL_DIM>& conv_p, const std::uint8_t* A, std::int32_t A_zero_point, std::uint8_t* Ao); diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc index 7843aae..11bd625 100644 --- a/test/I8DepthwiseTest.cc +++ b/test/I8DepthwiseTest.cc @@ -14,7 +14,7 @@ #include "TestUtils.h" #include "bench/AlignedVec.h" #include "bench/BenchUtils.h" -#include "src/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "src/RefImplementations.h" using namespace std; diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc index 1ec2a06..d9c2f75 100644 --- a/test/Im2ColFusedRequantizeTest.cc +++ b/test/Im2ColFusedRequantizeTest.cc @@ -610,7 +610,7 @@ static void Im2col3DTest(bool b_symmetric) { // computing row offset vector<int32_t> row_offsets(MDim); vector<uint8_t> Aint8_im2col(MDim * KDim); - im2col3d_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); // computing column offset vector<int32_t> col_offsets(conv_p.G * NDim); @@ -625,7 +625,7 @@ static void Im2col3DTest(bool b_symmetric) { ncols_per_quant_group); } - conv3d_ref( + conv_ref( conv_p, Aint8.data(), Aint8_zero_point, diff --git a/test/UniConvPackingTest.cc b/test/UniConvPackingTest.cc new file mode 100644 index 0000000..77552af --- /dev/null +++ b/test/UniConvPackingTest.cc @@ -0,0 +1,148 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <random> +#include <iostream> + + +#include <gtest/gtest.h> + +#include "QuantizationHelpers.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +namespace { + +// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad +class convPackingTest + : public testing::TestWithParam< + tuple<int, int, int, int, int, int, int, int, int, int>> {}; + +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + convPackingTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2}), // MB + ::testing::ValuesIn({16, 32}), // IC + ::testing::ValuesIn({16, 32}), // OC + ::testing::ValuesIn({17}), // IT + ::testing::ValuesIn({10, 30, 55}), // IH + ::testing::ValuesIn({10, 30, 55}), // IW + ::testing::ValuesIn({1, 4, 16}), // G + ::testing::ValuesIn({3, 7}), // kernel + ::testing::ValuesIn({1, 2}), // stride + ::testing::ValuesIn({1, 2}))); // pad + +/** + * Test for conv packing + */ +TEST_P(convPackingTest, packingTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + switch (ConvFastPath<2, int32_t>(conv_p_2d)) { + case optimized_conv_t::depthwise: { + ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + switch (ConvFastPath<3, int32_t>(conv_p_3d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_TRUE(false) << "groupwise are not supported for 3D"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) + << "2D depthwise packed matrix is null"; + ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr) + << "3D depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } +} |