diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-01-31 02:06:15 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-31 02:17:18 +0300 |
commit | 03a8fa506ee0822e3ab045f8e3df902abcfd2e74 (patch) | |
tree | 95a5b96dfad4707fd006a9e7e00fad9becd63c6d | |
parent | 79333308f5e2fc242727879dcd3de3536b6ffc39 (diff) |
Add threading for FBGEMM FP16
Summary: Add threading support for FBGEMM FP16 routines.
Reviewed By: dskhudia, jacobkahn
Differential Revision: D13792341
fbshipit-source-id: eb31a11340ac9fd0ee9b4f570d161e7c7e6a7602
-rw-r--r-- | bench/FP16Benchmark.cc | 39 | ||||
-rw-r--r-- | include/fbgemm/FbgemmFP16.h | 23 | ||||
-rw-r--r-- | src/FbgemmFP16.cc | 136 | ||||
-rw-r--r-- | test/FP16Test.cc | 16 |
4 files changed, 140 insertions, 74 deletions
diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc index efb043d..f3a2e2d 100644 --- a/bench/FP16Benchmark.cc +++ b/bench/FP16Benchmark.cc @@ -122,8 +122,22 @@ void performance_test() { C_ref.data(), n); #endif - cblas_gemm_compute( - matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data()); +#ifdef _OPENMP +#pragma omp parallel +#endif + { + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + cblas_gemm_compute( + matrix_op_t::NoTranspose, + m, + A.data(), + Bp, + beta, + C_fb.data(), + tid, + num_threads); + } #if defined(USE_MKL) || defined(USE_BLAS) // Compare results @@ -201,8 +215,25 @@ void performance_test() { } t_begin = chrono::system_clock::now(); - cblas_gemm_compute( - matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data()); + +#ifdef _OPENMP +#pragma omp parallel +#endif + { + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + cblas_gemm_compute( + matrix_op_t::NoTranspose, + m, + A.data(), + Bp, + beta, + C_fb.data(), + tid, + num_threads); + } + t_end = chrono::system_clock::now(); if (it >= 0) { diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h index bebeb70..96deb49 100644 --- a/include/fbgemm/FbgemmFP16.h +++ b/include/fbgemm/FbgemmFP16.h @@ -192,14 +192,9 @@ class PackedGemmMatrixFP16 { const float* A, const PackedGemmMatrixFP16& Bp, const float beta, - float* C); - friend void cblas_gemm_compute( - const matrix_op_t transa, - const int m, - const float* A, - const PackedGemmMatrixFP16& Bp, - const float beta, - float* C); + float* C, + int thread_id, + int num_threads); }; /** @@ -211,13 +206,7 @@ extern void cblas_gemm_compute( const float* A, const PackedGemmMatrixFP16& Bp, const float beta, - float* C); -extern void cblas_gemm_compute( - const matrix_op_t transa, - const int m, - const float* A, - const PackedGemmMatrixFP16& Bp, - const float beta, - float* C); - + float* C, + int thread_id = 0, + int num_threads = 1); }; // namespace fbgemm diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index 69680bc..d3d5c1f 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -6,6 +6,8 @@ */ #include "fbgemm/FbgemmFP16.h" +#include "fbgemm/Fbgemm.h" + #include <cpuinfo.h> #include <array> #include <utility> @@ -34,21 +36,24 @@ struct KernelInfo { using knl_ptr = funcptr_fp16; // optimized kernels to cover all cases static constexpr array<knl_ptr, 15> kernel = { - {nullptr, - gemmkernel_1x1_AVX2_fA0fB0fC0, - gemmkernel_2x1_AVX2_fA0fB0fC0, - gemmkernel_3x1_AVX2_fA0fB0fC0, - gemmkernel_4x1_AVX2_fA0fB0fC0, - gemmkernel_5x1_AVX2_fA0fB0fC0, - gemmkernel_6x1_AVX2_fA0fB0fC0, - gemmkernel_7x1_AVX2_fA0fB0fC0, - gemmkernel_8x1_AVX2_fA0fB0fC0, - gemmkernel_9x1_AVX2_fA0fB0fC0, - gemmkernel_10x1_AVX2_fA0fB0fC0, - gemmkernel_11x1_AVX2_fA0fB0fC0, - gemmkernel_12x1_AVX2_fA0fB0fC0, - gemmkernel_13x1_AVX2_fA0fB0fC0, - gemmkernel_14x1_AVX2_fA0fB0fC0}}; + { + nullptr, + gemmkernel_1x1_AVX2_fA0fB0fC0, + gemmkernel_2x1_AVX2_fA0fB0fC0, + gemmkernel_3x1_AVX2_fA0fB0fC0, + gemmkernel_4x1_AVX2_fA0fB0fC0, + gemmkernel_5x1_AVX2_fA0fB0fC0, + gemmkernel_6x1_AVX2_fA0fB0fC0, + gemmkernel_7x1_AVX2_fA0fB0fC0, + gemmkernel_8x1_AVX2_fA0fB0fC0, + gemmkernel_9x1_AVX2_fA0fB0fC0, + gemmkernel_10x1_AVX2_fA0fB0fC0, + gemmkernel_11x1_AVX2_fA0fB0fC0, + gemmkernel_12x1_AVX2_fA0fB0fC0, + gemmkernel_13x1_AVX2_fA0fB0fC0, + gemmkernel_14x1_AVX2_fA0fB0fC0 + } + }; // autotuned kernel splits for various cases m = 1:mb_max // may need re-autotuning for new uarch @@ -177,7 +182,7 @@ struct KernelInfo { {{ { 12, 9 }, { 10, 1 } } }, {{ { 12, 9 }, { 11, 1 } } }, {{ { 12, 10 }, { 0, 0 } } } - } + } }; }; constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel; @@ -190,7 +195,9 @@ FBGEMM_API void cblas_gemm_compute( const float* A, const PackedGemmMatrixFP16& Bp, const float beta, - float* C) { + float* C, + int thread_id, + int num_threads) { // ground truth assert(cpuinfo_initialize()); assert(cpuinfo_has_x86_fma3()); @@ -209,8 +216,13 @@ FBGEMM_API void cblas_gemm_compute( new std::array<float, 256 * 1024>()); GemmParams gp; - for (auto m0 = 0; m0 < m; m0 += mb_max) { - int mb = std::min(mb_max, m - m0); + + int i_begin, i_end; + // fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end); + i_begin = 0; + i_end = m; + for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) { + int mb = std::min(mb_max, i_end - m0); assert(mb < KernelInfo::partition.size()); for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) { // set up proper accumulation to avoid "Nan" problem @@ -249,46 +261,66 @@ FBGEMM_API void cblas_gemm_compute( gp.ldc = ldc * sizeof(C[0]); gp.b_block_cols = nbcol; gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]); + if ((n % Bp.blockColSize()) == 0) { - KernelInfo::kernel[kernel_nrows](&gp); + int jb_begin, jb_end; + fbgemmGetRange( + num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end); + gp.B += gp.k * Bp.blockColSize() * jb_begin; + gp.C += 8 * jb_begin; + gp.b_block_cols = jb_end - jb_begin; + if (gp.b_block_cols) { + KernelInfo::kernel[kernel_nrows](&gp); + } } else { int last_blk_col = nbcol * Bp.blockColSize(); if (nbcol) { - KernelInfo::kernel[kernel_nrows](&gp); + int jb_begin, jb_end; + fbgemmGetRange( + num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end); + gp.B += gp.k * Bp.blockColSize() * jb_begin; + gp.C += 8 * jb_begin; + gp.b_block_cols = jb_end - jb_begin; + if (gp.b_block_cols) { + KernelInfo::kernel[kernel_nrows](&gp); + } } - // leftover - int rem = n - last_blk_col; - assert(rem < kernel_ncols); - int b = (rem % simd_width) ? ((rem + simd_width) / simd_width) - : (rem / simd_width); - assert(b == 1); - if ((rem % simd_width) == 0) { - gp.B = &(Bp(k_ind, last_blk_col)); - gp.C = &C[m2 * ldc + last_blk_col]; - gp.b_block_cols = 1; - KernelInfo::kernel[kernel_nrows](&gp); - } else { - // small temporary buffer - float c_tmp[16 * 24] = {0}; - assert((16 * 24) > kernel_nrows * kernel_ncols); + // use one thread to handle the fringe cases + if (thread_id == num_threads - 1) { + // leftover + int rem = n - last_blk_col; + assert(rem < kernel_ncols); + int b = (rem % simd_width) ? ((rem + simd_width) / simd_width) + : (rem / simd_width); + assert(b == 1); + if ((rem % simd_width) == 0) { + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = &C[m2 * ldc + last_blk_col]; + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + } else { + // small temporary buffer + float c_tmp[16 * 24] = {0}; + assert((16 * 24) > kernel_nrows * kernel_ncols); - gp.B = &(Bp(k_ind, last_blk_col)); - gp.C = c_tmp; - gp.ldc = 8 * sizeof(C[0]); - gp.b_block_cols = 1; - KernelInfo::kernel[kernel_nrows](&gp); - for (int i = 0; i < kernel_nrows; i++) { - // Todo: use assembly - for (int j = last_blk_col; j < n; j++) { - assert( - i * 8 + (j - last_blk_col) < - sizeof(c_tmp) / sizeof(c_tmp[0])); - if (accum == 0) { - C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)]; - } else { - C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + - c_tmp[i * 8 + (j - last_blk_col)]; + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = c_tmp; + gp.ldc = 8 * sizeof(C[0]); + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + for (int i = 0; i < kernel_nrows; i++) { + // Todo: use assembly + for (int j = last_blk_col; j < n; j++) { + assert( + i * 8 + (j - last_blk_col) < + sizeof(c_tmp) / sizeof(c_tmp[0])); + if (accum == 0) { + C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)]; + } else { + C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + + c_tmp[i * 8 + (j - last_blk_col)]; + } } } } diff --git a/test/FP16Test.cc b/test/FP16Test.cc index 7890374..eb49086 100644 --- a/test/FP16Test.cc +++ b/test/FP16Test.cc @@ -6,6 +6,10 @@ */ #include <random> +#ifdef _OPENMP +#include <omp.h> +#endif + #include <gtest/gtest.h> #include "TestUtils.h" @@ -97,7 +101,17 @@ TEST_P(FBGemmFP16Test, Test) { // fbgemm fp16 PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); - cblas_gemm_compute(atrans, m, A.data(), Bp, beta, C.data()); + +#ifdef _OPENMP +#pragma omp parallel +#endif + { + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + cblas_gemm_compute( + atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads); + } // correctness check for (int i = 0; i < m; ++i) { |