Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/Utils_avx512.cc')
-rw-r--r--src/Utils_avx512.cc243
1 files changed, 243 insertions, 0 deletions
diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc
new file mode 100644
index 0000000..b6bf413
--- /dev/null
+++ b/src/Utils_avx512.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm/Utils.h"
+
+#include <immintrin.h>
+
+namespace fbgemm2 {
+
+inline void transpose_kernel_16x16_avx512(
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // load from src to registers
+ // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
+ // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
+ // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
+ // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
+ // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
+ // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
+ // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
+ // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
+ // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
+ // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
+ // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
+ // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
+ // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
+ // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
+ // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
+ // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
+ __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
+ __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
+ __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
+ __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
+ __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
+ __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
+ __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
+ __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
+ __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
+ __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
+ __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
+ __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
+ __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
+ __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
+ __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
+ __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
+
+ __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
+ // unpacking and interleaving 32-bit elements
+ // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
+ // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
+ // c0 d0 c1 d1 ...
+ // c2 d2 c3 d3 ...
+ // e0 f0 e1 f1 ...
+ // e2 f2 e3 f3 ...
+ // g0 h0 g1 h1 ...
+ // g2 h2 g3 h3 ...
+ // i0 ...
+ // i2 ...
+ // k0 ...
+ // k2 ...
+ // m0 ...
+ // m2 ...
+ // o0 ...
+ // o1 ...
+ ta = _mm512_unpacklo_ps(a, b);
+ tb = _mm512_unpackhi_ps(a, b);
+ tc = _mm512_unpacklo_ps(c, d);
+ td = _mm512_unpackhi_ps(c, d);
+ te = _mm512_unpacklo_ps(e, f);
+ tf = _mm512_unpackhi_ps(e, f);
+ tg = _mm512_unpacklo_ps(g, h);
+ th = _mm512_unpackhi_ps(g, h);
+ ti = _mm512_unpacklo_ps(i, j);
+ tj = _mm512_unpackhi_ps(i, j);
+ tk = _mm512_unpacklo_ps(k, l);
+ tl = _mm512_unpackhi_ps(k, l);
+ tm = _mm512_unpacklo_ps(m, n);
+ tn = _mm512_unpackhi_ps(m, n);
+ to = _mm512_unpacklo_ps(o, p);
+ tq = _mm512_unpackhi_ps(o, p);
+
+ // unpacking and interleaving 64-bit elements
+ // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
+ // e1 f1 g1 h1 ...
+ // e2 f2 g2 h2 ...
+ // e3 f3 g3 h3 ...
+ // i0 j0 k0 l0 ...
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // m0 n0 o0 p0 ...
+ // m1 n1 o1 p1 ...
+ // m2 n2 o2 p2 ...
+ // m3 n3 o3 p3 ...
+ a = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ b = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ c = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ d = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ e = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ f = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ g = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ h = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ i = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ j = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ k = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ l = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ m = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ n = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ o = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+ p = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // a4 b4 c4 d4 ...
+ // a5 b5 c5 d5 ...
+ // a6 b6 c6 d6 ...
+ // a7 b7 c7 d7 ...
+ // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // i4 j4 k4 l4 ...
+ // i5 j5 k5 l5 ...
+ // i6 j6 k6 l6 ...
+ // i7 j7 k7 l7 ...
+ ta = _mm512_shuffle_f32x4(a, e, 0x88);
+ tb = _mm512_shuffle_f32x4(b, f, 0x88);
+ tc = _mm512_shuffle_f32x4(c, g, 0x88);
+ td = _mm512_shuffle_f32x4(d, h, 0x88);
+ te = _mm512_shuffle_f32x4(a, e, 0xdd);
+ tf = _mm512_shuffle_f32x4(b, f, 0xdd);
+ tg = _mm512_shuffle_f32x4(c, g, 0xdd);
+ th = _mm512_shuffle_f32x4(d, h, 0xdd);
+ ti = _mm512_shuffle_f32x4(i, m, 0x88);
+ tj = _mm512_shuffle_f32x4(j, n, 0x88);
+ tk = _mm512_shuffle_f32x4(k, o, 0x88);
+ tl = _mm512_shuffle_f32x4(l, p, 0x88);
+ tm = _mm512_shuffle_f32x4(i, m, 0xdd);
+ tn = _mm512_shuffle_f32x4(j, n, 0xdd);
+ to = _mm512_shuffle_f32x4(k, o, 0xdd);
+ tq = _mm512_shuffle_f32x4(l, p, 0xdd);
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 ... o0
+ // a1 b1 c1 d1 ... o1
+ // a2 b2 c2 d2 ... o2
+ // a3 b3 c3 d3 ... o3
+ // a4 ...
+ // a5 ...
+ // a6 ...
+ // a7 ...
+ // a8 ...
+ // a9 ...
+ // a10 ...
+ // a11 ...
+ // a12 ...
+ // a13 ...
+ // a14 ...
+ // a15 b15 c15 d15 ... o15
+ a = _mm512_shuffle_f32x4(ta, ti, 0x88);
+ b = _mm512_shuffle_f32x4(tb, tj, 0x88);
+ c = _mm512_shuffle_f32x4(tc, tk, 0x88);
+ d = _mm512_shuffle_f32x4(td, tl, 0x88);
+ e = _mm512_shuffle_f32x4(te, tm, 0x88);
+ f = _mm512_shuffle_f32x4(tf, tn, 0x88);
+ g = _mm512_shuffle_f32x4(tg, to, 0x88);
+ h = _mm512_shuffle_f32x4(th, tq, 0x88);
+ i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
+ j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
+ k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
+ l = _mm512_shuffle_f32x4(td, tl, 0xdd);
+ m = _mm512_shuffle_f32x4(te, tm, 0xdd);
+ n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
+ o = _mm512_shuffle_f32x4(tg, to, 0xdd);
+ p = _mm512_shuffle_f32x4(th, tq, 0xdd);
+
+ // store from registers to dst
+ _mm512_storeu_ps(&dst[0 * ld_dst], a);
+ _mm512_storeu_ps(&dst[1 * ld_dst], b);
+ _mm512_storeu_ps(&dst[2 * ld_dst], c);
+ _mm512_storeu_ps(&dst[3 * ld_dst], d);
+ _mm512_storeu_ps(&dst[4 * ld_dst], e);
+ _mm512_storeu_ps(&dst[5 * ld_dst], f);
+ _mm512_storeu_ps(&dst[6 * ld_dst], g);
+ _mm512_storeu_ps(&dst[7 * ld_dst], h);
+ _mm512_storeu_ps(&dst[8 * ld_dst], i);
+ _mm512_storeu_ps(&dst[9 * ld_dst], j);
+ _mm512_storeu_ps(&dst[10 * ld_dst], k);
+ _mm512_storeu_ps(&dst[11 * ld_dst], l);
+ _mm512_storeu_ps(&dst[12 * ld_dst], m);
+ _mm512_storeu_ps(&dst[13 * ld_dst], n);
+ _mm512_storeu_ps(&dst[14 * ld_dst], o);
+ _mm512_storeu_ps(&dst[15 * ld_dst], p);
+}
+
+void transpose_16x16(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 16 <= M; ib += 16) {
+ for (jb = 0; jb + 16 <= N; jb += 16) {
+ transpose_kernel_16x16_avx512(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+} // namespace fbgemm2