Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jianyu Huang <jianyuhuang@fb.com> 2019-03-18 09:35:08 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com> 2019-03-18 09:38:29 +0300
commit 1351790c8ca6418ab36697274c4d6f3cea3c140c (patch)
tree f5ccb122c0b7e7223c43df4d5fa279aee89696a9
parent 6011ce3b0c1fccee549e85b37e475c7a734ad742 (diff)
Add the Naive bfloat16 implementation based on MKL
Summary: Add the Naive bfloat16 implementation based on MKL. For this Naive bfloat16 implementation for C += A * B (A, B, and C are all bfloat16 type), we do the following three steps: 1. Convert bfloat16 A, B, C to fp32; 2. Call cblas_sgemm from MKL/BLAS; 3. Convert fp32 C back to bfloat16 C. Reviewed By: jspark1105 Differential Revision: D14391444 fbshipit-source-id: 1147dd2a18c4bbdec6c15f1d0f15d698d3741afe
-rw-r--r-- src/RefImplementations.cc 39
-rw-r--r-- src/RefImplementations.h 19
2 files changed, 58 insertions, 0 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 5f1277f..72ef93f 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -6,6 +6,8 @@
*/
#include "RefImplementations.h"
+#include "fbgemm/Types.h"
+
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -166,6 +168,43 @@ void matmul_fp_ref(
}
}
+void cblas_sgemm_ref( // naive triple-loop GEMM: C = alpha * op(A) * op(B) + beta * C
+ const matrix_op_t transa, // NoTranspose: A element (i, p) is Afp32[i * lda + p]; otherwise Afp32[p * lda + i]
+ const matrix_op_t transb, // NoTranspose: B element (p, j) is Bfp32[p * ldb + j]; otherwise Bfp32[j * ldb + p]
+ const int m, // number of rows of C that are written
+ const int n, // number of columns of C that are written
+ const int k, // length of each dot product (shared inner dimension)
+ float alpha, // scale applied to every accumulated dot product
+ const float* Afp32, // input A, read only
+ int lda, // multiplier of the first index when addressing Afp32
+ const float* Bfp32, // input B, read only
+ int ldb, // multiplier of the first index when addressing Bfp32
+ float beta, // scale applied to the existing C value; 0 means C is overwritten
+ float* Cfp32, // output C, element (i, j) is Cfp32[i * ldc + j]
+ int ldc // multiplier of the row index when addressing Cfp32
+ ) {
+ for (int i = 0; i < m; ++i) { // one pass per output row
+ for (int j = 0; j < n; ++j) { // one pass per output column
+ float sum = 0; // fp32 accumulator for the dot product of row i of op(A) and column j of op(B)
+ for (int p = 0; p < k; ++p) {
+ float a = // element (i, p) of op(A); the transposed case swaps the two indices
+ (transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p]
+ : Afp32[p * lda + i]);
+ float b = // element (p, j) of op(B); the transposed case swaps the two indices
+ (transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j]
+ : Bfp32[j * ldb + p]);
+ sum += a * b;
+ }
+ if (beta == 0) { // beta == 0: store without reading the previous C value
+ Cfp32[i * ldc + j] = alpha * sum;
+ } else { // otherwise blend the scaled product with the value already in C
+ Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j];
+ }
+ }
+ }
+}
+
+
void row_offsets_u8acc32_ref(
int M,
int K,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index 62f17e9..117c8a1 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -108,6 +108,25 @@ void FBGEMM_API matmul_fp_ref(
float* Cfp32);
/**
+ * @brief Reference implementation of cblas_sgemm in MKL/BLAS.
+ */
+void FBGEMM_API cblas_sgemm_ref( // computes C = alpha * op(A) * op(B) + beta * C
+ const matrix_op_t transa, // NoTranspose reads A as given; any other value reads it with indices swapped
+ const matrix_op_t transb, // NoTranspose reads B as given; any other value reads it with indices swapped
+ const int m, // number of rows of C
+ const int n, // number of columns of C
+ const int k, // shared inner dimension of op(A) and op(B)
+ float alpha, // scale applied to the product op(A) * op(B)
+ const float* Afp32, // input matrix A
+ int lda, // multiplier of the first index when addressing Afp32
+ const float* Bfp32, // input matrix B
+ int ldb, // multiplier of the first index when addressing Bfp32
+ float beta, // scale applied to the existing C; when 0 the old C is not read
+ float* Cfp32, // output matrix C, element (i, j) at Cfp32[i * ldc + j]
+ int ldc // multiplier of the row index when addressing Cfp32
+ );
+
+/**
* @brief Reference implementation to compute row_offsets (sums of rows of A).
*/
FBGEMM_API void row_offsets_u8acc32_ref(