
Repository: github.com/marian-nmt/FBGEMM.git
author     Jianyu Huang <jianyuhuang@fb.com>    2019-01-31 02:06:15 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-01-31 02:17:18 +0300
commit     03a8fa506ee0822e3ab045f8e3df902abcfd2e74 (patch)
tree       95a5b96dfad4707fd006a9e7e00fad9becd63c6d
parent     79333308f5e2fc242727879dcd3de3536b6ffc39 (diff)
Add threading for FBGEMM FP16
Summary: Add threading support for FBGEMM FP16 routines.

Reviewed By: dskhudia, jacobkahn

Differential Revision: D13792341

fbshipit-source-id: eb31a11340ac9fd0ee9b4f570d161e7c7e6a7602
-rw-r--r--   bench/FP16Benchmark.cc         39
-rw-r--r--   include/fbgemm/FbgemmFP16.h    23
-rw-r--r--   src/FbgemmFP16.cc             136
-rw-r--r--   test/FP16Test.cc               16
4 files changed, 140 insertions(+), 74 deletions(-)
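
Usage note: with this change the caller is responsible for spawning the threads (typically an OpenMP parallel region) and passing each thread's id and the total thread count into cblas_gemm_compute. A minimal sketch, modeled on the benchmark and test changes below; the wrapper function and its arguments are illustrative and not part of the diff, and it assumes fbgemm_get_num_threads() / fbgemm_get_thread_num() are visible from the FBGEMM headers as they are in the benchmark:

#include "fbgemm/FbgemmFP16.h"

using namespace fbgemm;

// Multiply the m x k matrix A by the k x n matrix B into C, splitting the
// work across OpenMP threads. B is packed once and shared (read-only) by
// every thread; each thread reports its id and the thread count.
void sgemm_fp16_parallel(
    int m, int n, int k, const float* A, const float* B, float* C) {
  PackedGemmMatrixFP16 Bp(matrix_op_t::NoTranspose, k, n, /*alpha=*/1.f, B);
#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int num_threads = fbgemm_get_num_threads();
    int tid = fbgemm_get_thread_num();
    cblas_gemm_compute(
        matrix_op_t::NoTranspose, m, A, Bp, /*beta=*/0.f, C, tid, num_threads);
  }
}
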
diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc
index efb043d..f3a2e2d 100644
--- a/bench/FP16Benchmark.cc
+++ b/bench/FP16Benchmark.cc
@@ -122,8 +122,22 @@ void performance_test() {
C_ref.data(),
n);
#endif
- cblas_gemm_compute(
- matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data());
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+ cblas_gemm_compute(
+ matrix_op_t::NoTranspose,
+ m,
+ A.data(),
+ Bp,
+ beta,
+ C_fb.data(),
+ tid,
+ num_threads);
+ }
#if defined(USE_MKL) || defined(USE_BLAS)
// Compare results
@@ -201,8 +215,25 @@ void performance_test() {
}
t_begin = chrono::system_clock::now();
- cblas_gemm_compute(
- matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data());
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ cblas_gemm_compute(
+ matrix_op_t::NoTranspose,
+ m,
+ A.data(),
+ Bp,
+ beta,
+ C_fb.data(),
+ tid,
+ num_threads);
+ }
+
t_end = chrono::system_clock::now();
if (it >= 0) {
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index bebeb70..96deb49 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -192,14 +192,9 @@ class PackedGemmMatrixFP16 {
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
- float* C);
- friend void cblas_gemm_compute(
- const matrix_op_t transa,
- const int m,
- const float* A,
- const PackedGemmMatrixFP16& Bp,
- const float beta,
- float* C);
+ float* C,
+ int thread_id,
+ int num_threads);
};
/**
@@ -211,13 +206,7 @@ extern void cblas_gemm_compute(
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
- float* C);
-extern void cblas_gemm_compute(
- const matrix_op_t transa,
- const int m,
- const float* A,
- const PackedGemmMatrixFP16& Bp,
- const float beta,
- float* C);
-
+ float* C,
+ int thread_id = 0,
+ int num_threads = 1);
}; // namespace fbgemm
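
Because thread_id and num_threads are defaulted in the header, the second (non-threaded) overload can be dropped without breaking existing single-threaded callers: a call that omits the new arguments, such as the hypothetical one below, still compiles and runs the whole computation on one thread.

// Equivalent to thread_id = 0, num_threads = 1.
cblas_gemm_compute(matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C.data());
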
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index 69680bc..d3d5c1f 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -6,6 +6,8 @@
*/
#include "fbgemm/FbgemmFP16.h"
+#include "fbgemm/Fbgemm.h"
+
#include <cpuinfo.h>
#include <array>
#include <utility>
@@ -34,21 +36,24 @@ struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
static constexpr array<knl_ptr, 15> kernel = {
- {nullptr,
- gemmkernel_1x1_AVX2_fA0fB0fC0,
- gemmkernel_2x1_AVX2_fA0fB0fC0,
- gemmkernel_3x1_AVX2_fA0fB0fC0,
- gemmkernel_4x1_AVX2_fA0fB0fC0,
- gemmkernel_5x1_AVX2_fA0fB0fC0,
- gemmkernel_6x1_AVX2_fA0fB0fC0,
- gemmkernel_7x1_AVX2_fA0fB0fC0,
- gemmkernel_8x1_AVX2_fA0fB0fC0,
- gemmkernel_9x1_AVX2_fA0fB0fC0,
- gemmkernel_10x1_AVX2_fA0fB0fC0,
- gemmkernel_11x1_AVX2_fA0fB0fC0,
- gemmkernel_12x1_AVX2_fA0fB0fC0,
- gemmkernel_13x1_AVX2_fA0fB0fC0,
- gemmkernel_14x1_AVX2_fA0fB0fC0}};
+ {
+ nullptr,
+ gemmkernel_1x1_AVX2_fA0fB0fC0,
+ gemmkernel_2x1_AVX2_fA0fB0fC0,
+ gemmkernel_3x1_AVX2_fA0fB0fC0,
+ gemmkernel_4x1_AVX2_fA0fB0fC0,
+ gemmkernel_5x1_AVX2_fA0fB0fC0,
+ gemmkernel_6x1_AVX2_fA0fB0fC0,
+ gemmkernel_7x1_AVX2_fA0fB0fC0,
+ gemmkernel_8x1_AVX2_fA0fB0fC0,
+ gemmkernel_9x1_AVX2_fA0fB0fC0,
+ gemmkernel_10x1_AVX2_fA0fB0fC0,
+ gemmkernel_11x1_AVX2_fA0fB0fC0,
+ gemmkernel_12x1_AVX2_fA0fB0fC0,
+ gemmkernel_13x1_AVX2_fA0fB0fC0,
+ gemmkernel_14x1_AVX2_fA0fB0fC0
+ }
+ };
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
@@ -177,7 +182,7 @@ struct KernelInfo {
{{ { 12, 9 }, { 10, 1 } } },
{{ { 12, 9 }, { 11, 1 } } },
{{ { 12, 10 }, { 0, 0 } } }
- }
+ }
};
};
constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
@@ -190,7 +195,9 @@ FBGEMM_API void cblas_gemm_compute(
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
- float* C) {
+ float* C,
+ int thread_id,
+ int num_threads) {
// ground truth
assert(cpuinfo_initialize());
assert(cpuinfo_has_x86_fma3());
@@ -209,8 +216,13 @@ FBGEMM_API void cblas_gemm_compute(
new std::array<float, 256 * 1024>());
GemmParams gp;
- for (auto m0 = 0; m0 < m; m0 += mb_max) {
- int mb = std::min(mb_max, m - m0);
+
+ int i_begin, i_end;
+ // fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end);
+ i_begin = 0;
+ i_end = m;
+ for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
+ int mb = std::min(mb_max, i_end - m0);
assert(mb < KernelInfo::partition.size());
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid "Nan" problem
@@ -249,46 +261,66 @@ FBGEMM_API void cblas_gemm_compute(
gp.ldc = ldc * sizeof(C[0]);
gp.b_block_cols = nbcol;
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
+
if ((n % Bp.blockColSize()) == 0) {
- KernelInfo::kernel[kernel_nrows](&gp);
+ int jb_begin, jb_end;
+ fbgemmGetRange(
+ num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
+ gp.B += gp.k * Bp.blockColSize() * jb_begin;
+ gp.C += 8 * jb_begin;
+ gp.b_block_cols = jb_end - jb_begin;
+ if (gp.b_block_cols) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ }
} else {
int last_blk_col = nbcol * Bp.blockColSize();
if (nbcol) {
- KernelInfo::kernel[kernel_nrows](&gp);
+ int jb_begin, jb_end;
+ fbgemmGetRange(
+ num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
+ gp.B += gp.k * Bp.blockColSize() * jb_begin;
+ gp.C += 8 * jb_begin;
+ gp.b_block_cols = jb_end - jb_begin;
+ if (gp.b_block_cols) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ }
}
- // leftover
- int rem = n - last_blk_col;
- assert(rem < kernel_ncols);
- int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
- : (rem / simd_width);
- assert(b == 1);
- if ((rem % simd_width) == 0) {
- gp.B = &(Bp(k_ind, last_blk_col));
- gp.C = &C[m2 * ldc + last_blk_col];
- gp.b_block_cols = 1;
- KernelInfo::kernel[kernel_nrows](&gp);
- } else {
- // small temporary buffer
- float c_tmp[16 * 24] = {0};
- assert((16 * 24) > kernel_nrows * kernel_ncols);
+ // use one thread to handle the fringe cases
+ if (thread_id == num_threads - 1) {
+ // leftover
+ int rem = n - last_blk_col;
+ assert(rem < kernel_ncols);
+ int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
+ : (rem / simd_width);
+ assert(b == 1);
+ if ((rem % simd_width) == 0) {
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = &C[m2 * ldc + last_blk_col];
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ } else {
+ // small temporary buffer
+ float c_tmp[16 * 24] = {0};
+ assert((16 * 24) > kernel_nrows * kernel_ncols);
- gp.B = &(Bp(k_ind, last_blk_col));
- gp.C = c_tmp;
- gp.ldc = 8 * sizeof(C[0]);
- gp.b_block_cols = 1;
- KernelInfo::kernel[kernel_nrows](&gp);
- for (int i = 0; i < kernel_nrows; i++) {
- // Todo: use assembly
- for (int j = last_blk_col; j < n; j++) {
- assert(
- i * 8 + (j - last_blk_col) <
- sizeof(c_tmp) / sizeof(c_tmp[0]));
- if (accum == 0) {
- C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
- } else {
- C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
- c_tmp[i * 8 + (j - last_blk_col)];
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = c_tmp;
+ gp.ldc = 8 * sizeof(C[0]);
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ for (int i = 0; i < kernel_nrows; i++) {
+ // Todo: use assembly
+ for (int j = last_blk_col; j < n; j++) {
+ assert(
+ i * 8 + (j - last_blk_col) <
+ sizeof(c_tmp) / sizeof(c_tmp[0]));
+ if (accum == 0) {
+ C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
+ } else {
+ C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
+ c_tmp[i * 8 + (j - last_blk_col)];
+ }
}
}
}
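
In the implementation above, every thread still walks the full row range (the row-wise fbgemmGetRange call is left commented out); the threads split only the column blocks of the packed B, and the last thread additionally handles the fringe columns that do not fill a whole block. The exact behavior of fbgemmGetRange comes from fbgemm/Fbgemm.h and is not shown in this diff; as a rough, assumed sketch, an even split over num_threads with a given block granularity could look like the following (illustrative only, not FBGEMM's code):

#include <algorithm>

// Hypothetical even work split: divide n items, in chunks of block_size,
// across num_threads, giving thread thread_id the half-open range
// [begin, end). This is an assumption about what a helper like
// fbgemmGetRange computes, not the library's actual implementation.
void get_range_sketch(
    int num_threads, int thread_id, int n, int block_size,
    int& begin, int& end) {
  int n_blocks = (n + block_size - 1) / block_size;
  int blocks_per_thread = (n_blocks + num_threads - 1) / num_threads;
  begin = std::min(thread_id * blocks_per_thread * block_size, n);
  end = std::min(begin + blocks_per_thread * block_size, n);
}
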
diff --git a/test/FP16Test.cc b/test/FP16Test.cc
index 7890374..eb49086 100644
--- a/test/FP16Test.cc
+++ b/test/FP16Test.cc
@@ -6,6 +6,10 @@
*/
#include <random>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
#include <gtest/gtest.h>
#include "TestUtils.h"
@@ -97,7 +101,17 @@ TEST_P(FBGemmFP16Test, Test) {
// fbgemm fp16
PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
- cblas_gemm_compute(atrans, m, A.data(), Bp, beta, C.data());
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ cblas_gemm_compute(
+ atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads);
+ }
// correctness check
for (int i = 0; i < m; ++i) {