Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/FbgemmFloat16Convert.cc')
-rw-r--r--src/FbgemmFloat16Convert.cc88
1 files changed, 88 insertions, 0 deletions
diff --git a/src/FbgemmFloat16Convert.cc b/src/FbgemmFloat16Convert.cc
new file mode 100644
index 0000000..3bd11b5
--- /dev/null
+++ b/src/FbgemmFloat16Convert.cc
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmConvert.h"
+
+#include "./RefImplementations.h"
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#ifdef USE_BLAS
+#include <cblas.h>
+#endif
+
+#include <cpuinfo.h>
+#include <memory>
+#include <utility>
+#include <vector>
+
+using namespace std;
+
+namespace fbgemm {
+
+void FloatToFloat16_ref(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ constexpr float FP16_MAX = 65504.f;
+ if (do_clip) {
+ for (int i = 0; i < size; i++) {
+ float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX));
+ dst[i] = cpu_float2half_rn(cur_src);
+ }
+ } else {
+ for (int i = 0; i < size; i++) {
+ dst[i] = cpu_float2half_rn(src[i]);
+ }
+ }
+}
+
+void Float16ToFloat_ref(const float16* src, float* dst, int size) {
+ for (int i = 0; i < size; i++) {
+ dst[i] = cpu_half2float(src[i]);
+ }
+}
+
+void FloatToFloat16_simd(
+ const float* src,
+ float16* dst,
+ int size,
+ bool do_clip) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ FloatToFloat16_avx512(src, dst, size, do_clip);
+ } else if (fbgemmHasAvx2Support()) {
+ FloatToFloat16_avx2(src, dst, size, do_clip);
+ } else {
+ FloatToFloat16_ref(src, dst, size, do_clip);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+void Float16ToFloat_simd(const float16* src, float* dst, int size) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ Float16ToFloat_avx512(src, dst, size);
+ } else if (fbgemmHasAvx2Support()) {
+ Float16ToFloat_avx2(src, dst, size);
+ } else {
+ Float16ToFloat_ref(src, dst, size);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+} // namespace fbgemm