diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-03-18 09:35:08 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-18 09:38:29 +0300 |
commit | 1351790c8ca6418ab36697274c4d6f3cea3c140c (patch) | |
tree | f5ccb122c0b7e7223c43df4d5fa279aee89696a9 /src | |
parent | 6011ce3b0c1fccee549e85b37e475c7a734ad742 (diff) |
Add the Naive bfloat16 implementation based on MKL
Summary:
Add the Naive bfloat16 implemenetation based on MKL.
For this Naive bfloat16 implementation for C += A * B (A, B, and C are all bfloat16 type), we do the following three steps:
1. Convert bfloat16 A, B, C to fp32;
2. Call cblas_sgemm from MKL/BLAS;
3. Convert fp32 C back to bfloat16 C.
Reviewed By: jspark1105
Differential Revision: D14391444
fbshipit-source-id: 1147dd2a18c4bbdec6c15f1d0f15d698d3741afe
Diffstat (limited to 'src')
-rw-r--r-- | src/RefImplementations.cc | 39 | ||||
-rw-r--r-- | src/RefImplementations.h | 19 |
2 files changed, 58 insertions, 0 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 5f1277f..72ef93f 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -6,6 +6,8 @@ */ #include "RefImplementations.h" +#include "fbgemm/Types.h" + #include <algorithm> #include <cassert> #include <cmath> @@ -166,6 +168,43 @@ void matmul_fp_ref( } } +void cblas_sgemm_ref( + const matrix_op_t transa, + const matrix_op_t transb, + const int m, + const int n, + const int k, + float alpha, + const float* Afp32, + int lda, + const float* Bfp32, + int ldb, + float beta, + float* Cfp32, + int ldc + ) { + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float sum = 0; + for (int p = 0; p < k; ++p) { + float a = + (transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p] + : Afp32[p * lda + i]); + float b = + (transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j] + : Bfp32[j * ldb + p]); + sum += a * b; + } + if (beta == 0) { + Cfp32[i * ldc + j] = alpha * sum; + } else { + Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j]; + } + } + } +} + + void row_offsets_u8acc32_ref( int M, int K, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index 62f17e9..117c8a1 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -108,6 +108,25 @@ void FBGEMM_API matmul_fp_ref( float* Cfp32); /** + * @brief Reference implementation of cblas_sgemm in MKL/BLAS. + */ +void FBGEMM_API cblas_sgemm_ref( + const matrix_op_t transa, + const matrix_op_t transb, + const int m, + const int n, + const int k, + float alpha, + const float* Afp32, + int lda, + const float* Bfp32, + int ldb, + float beta, + float* Cfp32, + int ldc + ); + +/** * @brief Reference implementation to compute row_offsets (sums of rows of A). */ FBGEMM_API void row_offsets_u8acc32_ref( |