Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYoung Jin Kim <youki@microsoft.com>2019-12-09 21:13:02 +0300
committerYoung Jin Kim <youki@microsoft.com>2019-12-09 21:13:02 +0300
commit828a68b406ee245ee25e7c05922cebf99979de58 (patch)
tree6834e0f18976d00f62ab1efeb484d01ad0d97988
parenteb530cd5a5ff88367de84a56473c5a5ed7a5905a (diff)
parent576037639285dcc268d02b0547aa54e9140b5d33 (diff)
Merge branch 'upstream' into youki/mergemaster1206youki/mergemaster1206
-rw-r--r--CMakeLists.txt10
-rw-r--r--include/fbgemm/FbgemmConvert.h137
-rw-r--r--src/FbgemmBfloat16Convert.cc85
-rw-r--r--src/FbgemmBfloat16ConvertAvx2.cc64
-rw-r--r--src/FbgemmBfloat16ConvertAvx512.cc57
-rw-r--r--src/FbgemmFloat16Convert.cc88
-rw-r--r--src/FbgemmFloat16ConvertAvx2.cc73
-rw-r--r--src/FbgemmFloat16ConvertAvx512.cc73
-rw-r--r--test/Bfloat16ConvertTest.cc83
-rw-r--r--test/Float16ConvertTest.cc119
10 files changed, 787 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6d7f548..3ac70b5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,8 +32,10 @@ endif(MSVC)
set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/ExecuteKernelU8S8.cc
src/Fbgemm.cc
- src/FbgemmFP16.cc
+ src/FbgemmBfloat16Convert.cc
src/FbgemmConv.cc
+ src/FbgemmFP16.cc
+ src/FbgemmFloat16Convert.cc
src/FbgemmI64.cc
src/FbgemmI8Spmdm.cc
src/FbgemmSpConv.cc
@@ -92,7 +94,9 @@ endif()
#All the source files that either use avx2 instructions statically
set(FBGEMM_AVX2_SRCS
+ src/FbgemmBfloat16ConvertAvx2.cc
src/FbgemmFP16UKernelsAvx2.cc
+ src/FbgemmFloat16ConvertAvx2.cc
src/FbgemmI8Depthwise3DAvx2.cc
src/FbgemmI8Depthwise3x3Avx2.cc
src/FbgemmI8DepthwiseAvx2.cc
@@ -104,7 +108,9 @@ set(FBGEMM_AVX2_SRCS
#All the source files that use avx512 instructions statically
set(FBGEMM_AVX512_SRCS
+ src/FbgemmBfloat16ConvertAvx512.cc
src/FbgemmFP16UKernelsAvx512.cc
+ src/FbgemmFloat16ConvertAvx512.cc
src/UtilsAvx512.cc)
set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
@@ -133,7 +139,7 @@ set_target_properties(fbgemm_generic fbgemm_avx2 fbgemm_avx512 PROPERTIES
if (NOT MSVC)
target_compile_options(fbgemm_avx2 PRIVATE
- "-m64" "-mavx2" "-mfma" "-masm=intel" "-mf16c")
+ "-m64" "-mavx2" "-mf16c" "-mfma" "-masm=intel")
target_compile_options(fbgemm_avx512 PRIVATE
"-m64" "-mavx2" "-mfma" "-mavx512f" "-mavx512bw" "-mavx512dq"
"-mavx512vl" "-masm=intel" "-mf16c")
diff --git a/include/fbgemm/FbgemmConvert.h b/include/fbgemm/FbgemmConvert.h
new file mode 100644
index 0000000..27787d6
--- /dev/null
+++ b/include/fbgemm/FbgemmConvert.h
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <stdexcept>
+#include "fbgemm/Types.h"
+#include "fbgemm/Utils.h"
+
+namespace fbgemm {
+
+typedef uint16_t bfloat16;
+
+/**
+ * @ Transform all entries in a matrix from fp32 to bfloat16: reference
+ * implementation.
+ *
+ */
+FBGEMM_API void FloatToBfloat16_ref(const float* src, bfloat16* dst, int size);
+
+/**
+ * @ Transform all entries in a matrix from bfloat16 to fp32: reference
+ * implementation.
+ *
+ */
+FBGEMM_API void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, int size);
+
+/**
+ * @ Transform all entries in a matrix from fp32 to bfloat16: simd
+ * implementation.
+ *
+ */
+FBGEMM_API void FloatToBfloat16_simd(const float* src, bfloat16* dst, int size);
+
+/**
+ * @ Transform all entries in a matrix from bfloat16 to fp32: simd
+ * implementation.
+ *
+ */
+FBGEMM_API void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, int size);
+
+/**
+ * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
+ *
+ */
+void FloatToBfloat16_avx2(const float* src, bfloat16* dst, int size);
+
+/**
+ * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
+ *
+ */
+void FloatToBfloat16_avx512(const float* src, bfloat16* dst, int size);
+
+/**
+ * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
+ *
+ */
+void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, int size);
+
+/**
+ * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
+ *
+ */
+void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, int size);
+
+/**
+ * @ Transform all entries in a matrix from fp32 to float16: reference
+ * implementation.
+ *
+ */
+FBGEMM_API void FloatToFloat16_ref(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip = false);
+
+/**
+ * @ Transform all entries in a matrix from float16 to fp32: reference
+ * implementation.
+ *
+ */
+FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, int size);
+
+/**
+ * @ Transform all entries in a matrix from fp32 to float16: simd
+ * implementation.
+ *
+ */
+FBGEMM_API void FloatToFloat16_simd(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip = false);
+
+/**
+ * @ Transform all entries in a matrix from float16 to fp32: simd
+ * implementation.
+ *
+ */
+FBGEMM_API void Float16ToFloat_simd(const float16* src, float* dst, int size);
+
+/**
+ * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
+ *
+ */
+void FloatToFloat16_avx2(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip = false);
+
+/**
+ * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
+ *
+ */
+void FloatToFloat16_avx512(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip = false);
+
+/**
+ * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
+ *
+ */
+void Float16ToFloat_avx2(const float16* src, float* dst, int size);
+
+/**
+ * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
+ *
+ */
+void Float16ToFloat_avx512(const float16* src, float* dst, int size);
+
+}; // namespace fbgemm
diff --git a/src/FbgemmBfloat16Convert.cc b/src/FbgemmBfloat16Convert.cc
new file mode 100644
index 0000000..cf0361a
--- /dev/null
+++ b/src/FbgemmBfloat16Convert.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmConvert.h"
+
+#include "./RefImplementations.h"
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#ifdef USE_BLAS
+#include <cblas.h>
+#endif
+
+#include <cpuinfo.h>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double naive_malloc_time = 0.0;
+double naive_A_bf16_to_fp32_time = 0.0;
+double naive_B_bf16_to_fp32_time = 0.0;
+double naive_C_bf16_to_fp32_time = 0.0;
+double naive_computing_time = 0.0;
+double naive_C_fp32_to_bf16_time = 0.0;
+double naive_run_time = 0.0;
+#endif
+
+using namespace std;
+
+namespace fbgemm {
+
+void FloatToBfloat16_ref(const float* src, bfloat16* dst, int size) {
+ for (int i = 0; i < size; i++) {
+ // Add 2^15 and right shift 16 to do round-nearest
+ dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16;
+ }
+}
+
+void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, int size) {
+ for (int i = 0; i < size; i++) {
+ uint32_t val_fp32 =
+ static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16;
+ reinterpret_cast<uint32_t*>(dst)[i] = val_fp32;
+ }
+}
+
+void FloatToBfloat16_simd(const float* src, bfloat16* dst, int size) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ FloatToBfloat16_avx512(src, dst, size);
+ } else if (fbgemmHasAvx2Support()) {
+ FloatToBfloat16_avx2(src, dst, size);
+ } else {
+ FloatToBfloat16_ref(src, dst, size);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, int size) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ Bfloat16ToFloat_avx512(src, dst, size);
+ } else if (fbgemmHasAvx2Support()) {
+ Bfloat16ToFloat_avx2(src, dst, size);
+ } else {
+ Bfloat16ToFloat_ref(src, dst, size);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmBfloat16ConvertAvx2.cc b/src/FbgemmBfloat16ConvertAvx2.cc
new file mode 100644
index 0000000..bf651b0
--- /dev/null
+++ b/src/FbgemmBfloat16ConvertAvx2.cc
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <immintrin.h>
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+namespace {
+
+inline __m256i QuantizeBfloat16Avx2(const __m256& x0, const __m256& x1) {
+ // Add 2^15 and right shift 16 to do round-nearest
+ __m256i y0 = _mm256_srli_epi32(
+ _mm256_add_epi32(
+ reinterpret_cast<__m256i>(x0), _mm256_set1_epi32(1 << 15)),
+ 16);
+ __m256i y1 = _mm256_srli_epi32(
+ _mm256_add_epi32(
+ reinterpret_cast<__m256i>(x1), _mm256_set1_epi32(1 << 15)),
+ 16);
+ // AVX2 doesn't have _mm256_cvtepi32_epi16 so we need this instruction
+ // sequence.
+ return _mm256_permute4x64_epi64(_mm256_packus_epi32(y0, y1), 0xd8);
+}
+
+inline void FloatToBfloat16KernelAvx2(const float* src, bfloat16* dst) {
+ // Two float m256i -> One bfloat16 m256i
+ const __m256 src_reg0 = _mm256_loadu_ps(src);
+ const __m256 src_reg1 = _mm256_loadu_ps(src + 8);
+ __m256i dst_reg = QuantizeBfloat16Avx2(src_reg0, src_reg1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg);
+}
+
+inline void Bfloat16ToFloatKernelAvx2(const bfloat16* src, float* dst) {
+ // One bfloat16 m128i -> One float m256i
+ const __m128i src_reg =
+ _mm_lddqu_si128(reinterpret_cast<const __m128i*>(src));
+ __m256i dst_reg_bf16 = _mm256_cvtepu16_epi32(src_reg);
+ __m256i dst_reg = _mm256_slli_epi32(dst_reg_bf16, 16);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg);
+}
+
+} // namespace
+
+void FloatToBfloat16_avx2(const float* src, bfloat16* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 8 * 2 <= size; i += 8 * 2) {
+ FloatToBfloat16KernelAvx2(src + i, dst + i);
+ }
+ FloatToBfloat16_ref(src + i, dst + i, size - i);
+}
+
+void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 8 <= size; i += 8) {
+ Bfloat16ToFloatKernelAvx2(src + i, dst + i);
+ }
+ Bfloat16ToFloat_ref(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmBfloat16ConvertAvx512.cc b/src/FbgemmBfloat16ConvertAvx512.cc
new file mode 100644
index 0000000..7d2975d
--- /dev/null
+++ b/src/FbgemmBfloat16ConvertAvx512.cc
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <immintrin.h>
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+namespace {
+
+inline __m256i QuantizeBfloat16Avx512(const __m512& x0) {
+ // Add 2^15 and right shift 16 to do round-nearest
+ __m512i y0 = _mm512_srli_epi32(
+ _mm512_add_epi32(
+ reinterpret_cast<__m512i>(x0), _mm512_set1_epi32(1 << 15)),
+ 16);
+ return _mm512_cvtepi32_epi16(y0);
+}
+
+inline void FloatToBfloat16KernelAvx512(const float* src, bfloat16* dst) {
+ // One float m512i -> One bfloat16 m256i
+ const __m512 src_reg0 = _mm512_loadu_ps(src);
+ __m256i dst_reg0 = QuantizeBfloat16Avx512(src_reg0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg0);
+}
+
+inline void Bfloat16ToFloatKernelAvx512(const bfloat16* src, float* dst) {
+ // One bfloat16 m256i -> One float m512i
+ const __m256i src_reg =
+ _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(src));
+ __m512i dst_reg_bf16 = _mm512_cvtepu16_epi32(src_reg);
+ __m512i dst_reg = _mm512_slli_epi32(dst_reg_bf16, 16);
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), dst_reg);
+}
+
+} // namespace
+
+void FloatToBfloat16_avx512(const float* src, bfloat16* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 16 <= size; i += 16) {
+ FloatToBfloat16KernelAvx512(src + i, dst + i);
+ }
+ FloatToBfloat16_avx2(src + i, dst + i, size - i);
+}
+
+void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 16 <= size; i += 16) {
+ Bfloat16ToFloatKernelAvx512(src + i, dst + i);
+ }
+ Bfloat16ToFloat_avx2(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmFloat16Convert.cc b/src/FbgemmFloat16Convert.cc
new file mode 100644
index 0000000..3bd11b5
--- /dev/null
+++ b/src/FbgemmFloat16Convert.cc
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmConvert.h"
+
+#include "./RefImplementations.h"
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#ifdef USE_BLAS
+#include <cblas.h>
+#endif
+
+#include <cpuinfo.h>
+#include <memory>
+#include <utility>
+#include <vector>
+
+using namespace std;
+
+namespace fbgemm {
+
+void FloatToFloat16_ref(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ constexpr float FP16_MAX = 65504.f;
+ if (do_clip) {
+ for (int i = 0; i < size; i++) {
+ float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX));
+ dst[i] = cpu_float2half_rn(cur_src);
+ }
+ } else {
+ for (int i = 0; i < size; i++) {
+ dst[i] = cpu_float2half_rn(src[i]);
+ }
+ }
+}
+
+void Float16ToFloat_ref(const float16* src, float* dst, int size) {
+ for (int i = 0; i < size; i++) {
+ dst[i] = cpu_half2float(src[i]);
+ }
+}
+
+void FloatToFloat16_simd(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ FloatToFloat16_avx512(src, dst, size, do_clip);
+ } else if (fbgemmHasAvx2Support()) {
+ FloatToFloat16_avx2(src, dst, size, do_clip);
+ } else {
+ FloatToFloat16_ref(src, dst, size, do_clip);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+void Float16ToFloat_simd(const float16* src, float* dst, int size) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ Float16ToFloat_avx512(src, dst, size);
+ } else if (fbgemmHasAvx2Support()) {
+ Float16ToFloat_avx2(src, dst, size);
+ } else {
+ Float16ToFloat_ref(src, dst, size);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmFloat16ConvertAvx2.cc b/src/FbgemmFloat16ConvertAvx2.cc
new file mode 100644
index 0000000..51290b0
--- /dev/null
+++ b/src/FbgemmFloat16ConvertAvx2.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <immintrin.h>
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+namespace {
+
+inline void FloatToFloat16KernelAvx2(const float* src, float16* dst) {
+ __m256 float_vector = _mm256_loadu_ps(src);
+ __m128i half_vector = _mm256_cvtps_ph(
+ float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*)dst, half_vector);
+}
+
+inline void FloatToFloat16KernelAvx2WithClip(const float* src, float16* dst) {
+ constexpr float FP16_MAX = 65504.f;
+ __m256 neg_fp16_max_vector = _mm256_set1_ps(-FP16_MAX);
+ __m256 pos_fp16_max_vector = _mm256_set1_ps(FP16_MAX);
+
+ __m256 float_vector = _mm256_loadu_ps(src);
+
+ // Do the clipping.
+ float_vector = _mm256_max_ps(
+ neg_fp16_max_vector, _mm256_min_ps(float_vector, pos_fp16_max_vector));
+
+ __m128i half_vector = _mm256_cvtps_ph(
+ float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*)dst, half_vector);
+}
+
+inline void Float16ToFloatKernelAvx2(const float16* src, float* dst) {
+ __m128i half_vector = _mm_loadu_si128((__m128i*)src);
+ __m256 float_vector = _mm256_cvtph_ps(half_vector);
+ _mm256_storeu_ps(dst, float_vector);
+}
+
+} // namespace
+
+void FloatToFloat16_avx2(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ if (do_clip) {
+ int i = 0;
+ for (i = 0; i + 8 <= size; i += 8) {
+ FloatToFloat16KernelAvx2WithClip(src + i, dst + i);
+ }
+ FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
+ } else {
+ int i = 0;
+ for (i = 0; i + 8 <= size; i += 8) {
+ FloatToFloat16KernelAvx2(src + i, dst + i);
+ }
+ FloatToFloat16_ref(src + i, dst + i, size - i);
+ }
+}
+
+void Float16ToFloat_avx2(const float16* src, float* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 8 <= size; i += 8) {
+ Float16ToFloatKernelAvx2(src + i, dst + i);
+ }
+ Float16ToFloat_ref(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmFloat16ConvertAvx512.cc b/src/FbgemmFloat16ConvertAvx512.cc
new file mode 100644
index 0000000..953dd0f
--- /dev/null
+++ b/src/FbgemmFloat16ConvertAvx512.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <immintrin.h>
+#include "fbgemm/FbgemmConvert.h"
+
+namespace fbgemm {
+
+namespace {
+
+inline void FloatToFloat16KernelAvx512(const float* src, float16* dst) {
+ __m512 float_vector = _mm512_loadu_ps(src);
+ __m256i half_vector = _mm512_cvtps_ph(
+ float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm256_storeu_si256((__m256i*)dst, half_vector);
+}
+
+inline void FloatToFloat16KernelAvx512WithClip(const float* src, float16* dst) {
+ constexpr float FP16_MAX = 65504.f;
+ __m512 neg_fp16_max_vector = _mm512_set1_ps(-FP16_MAX);
+ __m512 pos_fp16_max_vector = _mm512_set1_ps(FP16_MAX);
+
+ __m512 float_vector = _mm512_loadu_ps(src);
+
+ // Do the clipping.
+ float_vector = _mm512_max_ps(
+ neg_fp16_max_vector, _mm512_min_ps(float_vector, pos_fp16_max_vector));
+
+ __m256i half_vector = _mm512_cvtps_ph(
+ float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm256_storeu_si256((__m256i*)dst, half_vector);
+}
+
+inline void Float16ToFloatKernelAvx512(const float16* src, float* dst) {
+ __m256i half_vector = _mm256_loadu_si256((__m256i*)src);
+ __m512 float_vector = _mm512_cvtph_ps(half_vector);
+ _mm512_storeu_ps(dst, float_vector);
+}
+
+} // namespace
+
+void FloatToFloat16_avx512(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ if (do_clip) {
+ int i = 0;
+ for (i = 0; i + 16 <= size; i += 16) {
+ FloatToFloat16KernelAvx512WithClip(src + i, dst + i);
+ }
+ FloatToFloat16_avx2(src + i, dst + i, size - i, do_clip);
+ } else {
+ int i = 0;
+ for (i = 0; i + 16 <= size; i += 16) {
+ FloatToFloat16KernelAvx512(src + i, dst + i);
+ }
+ FloatToFloat16_avx2(src + i, dst + i, size - i);
+ }
+}
+
+void Float16ToFloat_avx512(const float16* src, float* dst, int size) {
+ int i = 0;
+ for (i = 0; i + 16 <= size; i += 16) {
+ Float16ToFloatKernelAvx512(src + i, dst + i);
+ }
+ Float16ToFloat_avx2(src + i, dst + i, size - i);
+}
+
+} // namespace fbgemm
diff --git a/test/Bfloat16ConvertTest.cc b/test/Bfloat16ConvertTest.cc
new file mode 100644
index 0000000..be6f9be
--- /dev/null
+++ b/test/Bfloat16ConvertTest.cc
@@ -0,0 +1,83 @@
+#include <gtest/gtest.h>
+#include <cmath>
+#include <random>
+
+#include "bench/BenchUtils.h"
+#include "fbgemm/FbgemmConvert.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+TEST(FBGemmBfloat16Test, Conversion) {
+ float a[100]; // fp32 type
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ bfloat16 b[100]; // bfloat16 type
+ float c[100]; // fp32 type
+ FloatToBfloat16_ref(a, b, 100);
+ Bfloat16ToFloat_ref(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^7) since bfloat16
+ // has 7 bits mantissa.
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128);
+ }
+}
+
+TEST(FBGemmBfloat16Test, Conversion_simd) {
+ float a[100]; // fp32 type
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ bfloat16 b[100]; // bfloat16 type
+ float c[100]; // fp32 type
+ FloatToBfloat16_simd(a, b, 100);
+ Bfloat16ToFloat_simd(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^7) since bfloat16
+ // has 7 bits mantissa.
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128)
+ << "Conversion results differ at (" << i << " ). ref: " << a[i]
+ << " conversion: " << c[i];
+ }
+}
+
+TEST(FBGemmBfloat16Test, Conversion_simd2) {
+ vector<vector<int>> shapes;
+ random_device r;
+ default_random_engine generator(r());
+ uniform_int_distribution<int> dm(1, 256);
+ uniform_int_distribution<int> dn(1, 1024);
+
+ for (int i = 0; i < 10; i++) {
+ int m = dm(generator);
+ int n = dn(generator);
+ shapes.push_back({m, n});
+ }
+
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+
+ cerr << "m = " << m << " n = " << n << endl;
+ aligned_vector<float> A_fp32_ref(m * n); // fp32 type
+ aligned_vector<bfloat16> A_bfloat16(m * n); // bfloat16 type
+ aligned_vector<float> A_fp32_final(m * n); // fp32 type
+ // randFill(A_fp32_ref, 0.0f, 4.0f);
+ for (int i = 0; i < m * n; ++i) {
+ A_fp32_ref[i] = i + 1.25;
+ }
+
+ FloatToBfloat16_simd(A_fp32_ref.data(), A_bfloat16.data(), m * n);
+ Bfloat16ToFloat_simd(A_bfloat16.data(), A_fp32_final.data(), m * n);
+ for (int i = 0; i < m * n; ++i) {
+ // The relative error should be less than 1/(2^7) since bfloat16
+ // has 7 bits mantissa.
+ // printf( "A_fp32_final[%d]: %f; A_fp32_ref[%d]: %f\n", i,
+ // A_fp32_final[i], i, A_fp32_ref[i]);
+ EXPECT_LE(
+ fabs(A_fp32_final[i] - A_fp32_ref[i]) / A_fp32_ref[i], 1.0 / 128);
+ }
+ }
+}
diff --git a/test/Float16ConvertTest.cc b/test/Float16ConvertTest.cc
new file mode 100644
index 0000000..7a797e1
--- /dev/null
+++ b/test/Float16ConvertTest.cc
@@ -0,0 +1,119 @@
+#include <gtest/gtest.h>
+#include <cmath>
+#include <random>
+
+#include "bench/BenchUtils.h"
+#include "fbgemm/FbgemmConvert.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+namespace {
+class FBGemmFloat16Test : public testing::TestWithParam<bool> {};
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ FBGemmFloat16Test,
+ ::testing::Bool());
+
+TEST_P(FBGemmFloat16Test, Conversion) {
+ bool do_clip = GetParam();
+ constexpr float FP16_MAX = 65504.f;
+
+ float a[100]; // fp32 type
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ if (do_clip) {
+ a[3] += 1024 * FP16_MAX;
+ }
+ float16 b[100]; // float16 type
+ float c[100]; // fp32 type
+ FloatToFloat16_ref(a, b, 100, do_clip);
+ Float16ToFloat_ref(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^10) since float16
+ // has 10 bits mantissa.
+ float expected = a[i];
+ if (do_clip) {
+ expected = std::max(-FP16_MAX, std::min(expected, FP16_MAX));
+ }
+ EXPECT_LE(fabs(expected - c[i]) / expected, 1.0 / 1024);
+ }
+}
+
+TEST_P(FBGemmFloat16Test, Conversion_simd) {
+ bool do_clip = GetParam();
+ constexpr float FP16_MAX = 65504.f;
+
+ float a[100]; // fp32 type
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ if (do_clip) {
+ a[3] += 1024 * FP16_MAX;
+ }
+ float16 b[100]; // float16 type
+ float c[100]; // fp32 type
+ FloatToFloat16_simd(a, b, 100, do_clip);
+ Float16ToFloat_simd(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^10) since float16
+ // has 10 bits mantissa.
+ float expected = a[i];
+ if (do_clip) {
+ expected = std::max(-FP16_MAX, std::min(expected, FP16_MAX));
+ }
+ EXPECT_LE(fabs(expected - c[i]) / expected, 1.0 / 1024);
+ }
+}
+
+TEST_P(FBGemmFloat16Test, Conversion_simd2) {
+ bool do_clip = GetParam();
+ constexpr float FP16_MAX = 65504.f;
+
+ vector<vector<int>> shapes;
+ random_device r;
+ default_random_engine generator(r());
+ uniform_int_distribution<int> dm(1, 256);
+ uniform_int_distribution<int> dn(1, 1024);
+
+ for (int i = 0; i < 10; i++) {
+ int m = dm(generator);
+ int n = dn(generator);
+ shapes.push_back({m, n});
+ }
+
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+
+ cerr << "m = " << m << " n = " << n << endl;
+ aligned_vector<float> A_fp32_ref(m * n); // fp32 type
+ aligned_vector<float16> A_float16(m * n); // float16 type
+ aligned_vector<float> A_fp32_final(m * n); // fp32 type
+ // randFill(A_fp32_ref, 0.0f, 4.0f);
+ for (int i = 0; i < m * n; ++i) {
+ A_fp32_ref[i] = (i % 10000) + 1.25;
+ }
+ if (do_clip) {
+ A_fp32_ref[0] += 1024 * FP16_MAX;
+ }
+
+ FloatToFloat16_simd(A_fp32_ref.data(), A_float16.data(), m * n, do_clip);
+ Float16ToFloat_simd(A_float16.data(), A_fp32_final.data(), m * n);
+ for (int i = 0; i < m * n; ++i) {
+ // The relative error should be less than 1/(2^10) since float16
+ // has 10 bits mantissa.
+ // printf( "A_fp32_final[%d]: %f; A_fp32_ref[%d]: %f\n", i,
+ // A_fp32_final[i], i, A_fp32_ref[i]);
+ float expected = A_fp32_ref[i];
+ if (do_clip) {
+ expected = std::max(-FP16_MAX, std::min(expected, FP16_MAX));
+ }
+ EXPECT_LE(fabs(expected - A_fp32_final[i]) / expected, 1.0 / 1024);
+ }
+ }
+}