Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Aleks Zi <zlateski@fb.com>, 2019-09-16 21:03:32 +0300
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>, 2019-09-16 21:07:52 +0300
commit: 96f2b9db2ea2972b6b8c04ed165a1854220a5e0b (patch)
tree: d833b0c81ff128543c294d977ebc86351ef50f0c
parent: 2f1477dfee9465c1e2dbdf21722970b3fa1baf86 (diff)
Small refactoring of FBGEMM GenerateKernel class
Summary: Removed unnecessary member variables, using sstream instead of strings.
Reviewed By: dskhudia
Differential Revision: D17134969
fbshipit-source-id: 147d0b39cde9edf5fb70762558e90dced5ba0ab1
-rw-r--r-- src/GenerateKernel.h 73
-rw-r--r-- src/GenerateKernelU8S8S32ACC16.cc 16
-rw-r--r-- src/GenerateKernelU8S8S32ACC16Avx512.cc 20
-rw-r--r-- src/GenerateKernelU8S8S32ACC32.cc 20
-rw-r--r-- src/GenerateKernelU8S8S32ACC32Avx512.cc 19
-rw-r--r-- src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc 18
6 files changed, 77 insertions, 89 deletions
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index 66f404b..c0fece4 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -9,6 +9,7 @@
#include <cpuinfo.h>
#include <map>
#include <mutex>
+#include <sstream>
#include <string>
#include <tuple>
#include "CodeCache.h"
@@ -42,35 +43,7 @@ class CodeGenBase {
* @brief Constructor for initializing AVX2/AVX512 registers.
*/
CodeGenBase(const BlockingFactors* params = nullptr)
- : blocking_params(params),
- CRegs_avx2_{x86::ymm0,
- x86::ymm1,
- x86::ymm2,
- x86::ymm3,
- x86::ymm4,
- x86::ymm5,
- x86::ymm6,
- x86::ymm7,
- x86::ymm8,
- x86::ymm9,
- x86::ymm10,
- x86::ymm11},
- CRegs_avx512_{
- x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
- x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
- x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
- x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
- x86::zmm25, x86::zmm26, x86::zmm27,
- },
- AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
- x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
- x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
- x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
- x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
- x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
- x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
+ : blocking_params(params) {
// vector width in bits
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
@@ -145,7 +118,7 @@ class CodeGenBase {
* (debug-only)
*/
template <inst_set_t instSet>
- std::string getCodeLoggingFile(
+ static std::string getCodeLoggingFile(
bool accum,
int mc,
int nc,
@@ -154,39 +127,35 @@ class CodeGenBase {
int MR,
int NR,
int NR_MIN) {
- std::string fileName = "gemm_";
+ std::ostringstream oss;
+ oss << "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
- fileName += "acc16_";
+ oss << "acc16_";
} else if (std::is_same<accT, std::int32_t>::value) {
- fileName += "acc32_";
+ oss << "acc32_";
} else {
- fileName += "unknown_";
+ oss << "unknown_";
}
- fileName += "accum-" + std::to_string(accum);
- fileName += "_MC-" + std::to_string(mc);
- fileName += "_NC-" + std::to_string(nc);
- fileName += "_NCB-" + std::to_string(NCB);
- fileName += "_NCB-" + std::to_string(KCB);
- fileName += "_MR-" + std::to_string(MR);
- fileName += "_NR-" + std::to_string(NR);
- fileName += "_NR_MIN-" + std::to_string(NR_MIN);
+ oss << "accum-" + std::to_string(accum)
+ << "_MC-" + std::to_string(mc)
+ << "_NC-" + std::to_string(nc)
+ << "_NCB-" + std::to_string(NCB)
+ << "_NCB-" + std::to_string(KCB)
+ << "_MR-" + std::to_string(MR)
+ << "_NR-" + std::to_string(NR)
+ << "_NR_MIN-" + std::to_string(NR_MIN);
if (instSet == inst_set_t::avx512_vnni) {
- fileName += "_avx512vnni";
+ oss << "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
- fileName += "_avx512";
+ oss << "_avx512";
} else if (instSet == inst_set_t::avx2) {
- fileName += "_avx2";
+ oss << "_avx2";
}
- fileName += ".txt";
- return fileName;
+ oss << ".txt";
+ return oss.str();
}
private:
- x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- x86::Zmm
- CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
-
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index f83012b..cbd5877 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -54,6 +55,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
x86::Ymm tmpReg = x86::ymm14;
+ using CRegs = x86::Ymm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
@@ -62,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpmaddubsw(
tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vpaddsw(
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
@@ -93,12 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
x86::Xmm extractDest128 = x86::xmm15;
x86::Ymm extractDest256 = x86::ymm15;
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(
- extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
+ extractDest128, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest256, extractDest128);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index b67d8e8..512c8ba 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -57,20 +58,22 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
a->vmovups(
- AllRegs_avx512_[27 - j],
+ x86::Zmm(27 - j),
x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
}
+ using CRegs = x86::Zmm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
+ a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j));
a->vpaddsw(
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
@@ -100,12 +103,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
x86::Ymm extractDest256 = x86::ymm31;
x86::Zmm extractDest512 = x86::zmm31;
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti32x8(
- extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
+ extractDest256, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest512, extractDest256);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index d39f153..a0fe26c 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -61,6 +62,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// temporary register
x86::Ymm res1 = x86::ymm14;
+ using CRegs = x86::Ymm;
+
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -71,9 +74,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -94,6 +97,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -101,13 +105,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)));
}
a->vmovups(
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)),
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
}
}
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 817b336..ecd7769 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -61,6 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// temporary register
x86::Zmm res1 = x86::zmm28;
+ using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -71,9 +73,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -94,6 +96,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -103,13 +106,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
}
a->vmovups(
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
}
}
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
index d6cb7c2..1d23e90 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -55,6 +56,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// used for matrix B
x86::Zmm BReg = x86::zmm30;
+ using CRegs = x86::Zmm;
+
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -62,7 +65,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
- a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
+ a->vpdpbusd(CRegs(i * leadingDimCReg + j), AReg, BReg);
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -83,6 +86,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -92,13 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
}
a->vmovups(
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
}
}