
github.com/marian-nmt/FBGEMM.git
author    Daya Khudia <dskhudia@fb.com>  2020-05-14 20:59:21 +0300
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>  2020-05-14 21:01:18 +0300
commit    46981b81867571cadfbbdadb019ad306eac29f51 (patch)
tree      6831171d34584556a530fe59dcc1870aef3bace1
parent    7ed5f9f16cffd1c17eae6b6d5f5f07f06cc83565 (diff)
Minor improvements in GEMM Kernels (#368)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/368

1) Replace vxorps with vpxor. vpxor can execute on multiple ports, and zeroing out an xmm register implicitly zeroes out the corresponding ymm/zmm register.
2) Use the integer version of vmov (on some architectures there is a penalty for crossing between the fp and integer vector domains).

Reference: https://stackoverflow.com/questions/33666617/what-is-the-best-way-to-set-a-register-to-zero-in-x86-assembly-xor-mov-or-and

No significant change in performance.

Before:
```
M, N, K, Type, GOPS
64, 800, 320, MKL_fp32, 46.3
64, 800, 320, FBGEMM_i8_acc32, 83.6
64, 800, 320, FBGEMM_i8_acc16, 79.2
64, 768, 512, MKL_fp32, 46.9
64, 768, 512, FBGEMM_i8_acc32, 88.1
64, 768, 512, FBGEMM_i8_acc16, 89.2
16, 256, 512, MKL_fp32, 27.6
16, 256, 512, FBGEMM_i8_acc32, 43.5
16, 256, 512, FBGEMM_i8_acc16, 54.1
128, 128, 128, MKL_fp32, 30.1
128, 128, 128, FBGEMM_i8_acc32, 42.1
128, 128, 128, FBGEMM_i8_acc16, 40.6
256, 512, 256, MKL_fp32, 44.8
256, 512, 256, FBGEMM_i8_acc32, 91.7
256, 512, 256, FBGEMM_i8_acc16, 91.1
1024, 1024, 1024, MKL_fp32, 48.8
1024, 1024, 1024, FBGEMM_i8_acc32, 97.0
1024, 1024, 1024, FBGEMM_i8_acc16, 97.6
```

After:
```
M, N, K, Type, GOPS
64, 800, 320, MKL_fp32, 46.2
64, 800, 320, FBGEMM_i8_acc32, 83.5
64, 800, 320, FBGEMM_i8_acc16, 80.8
64, 768, 512, MKL_fp32, 47.2
64, 768, 512, FBGEMM_i8_acc32, 88.5
64, 768, 512, FBGEMM_i8_acc16, 87.3
16, 256, 512, MKL_fp32, 26.0
16, 256, 512, FBGEMM_i8_acc32, 44.0
16, 256, 512, FBGEMM_i8_acc16, 54.5
128, 128, 128, MKL_fp32, 29.6
128, 128, 128, FBGEMM_i8_acc32, 42.2
128, 128, 128, FBGEMM_i8_acc16, 38.5
256, 512, 256, MKL_fp32, 44.3
256, 512, 256, FBGEMM_i8_acc32, 91.1
256, 512, 256, FBGEMM_i8_acc16, 91.0
1024, 1024, 1024, MKL_fp32, 48.7
1024, 1024, 1024, FBGEMM_i8_acc32, 96.6
1024, 1024, 1024, FBGEMM_i8_acc16, 96.5
```

Reviewed By: jspark1105

Differential Revision: D21433384

fbshipit-source-id: d0abd56f454293e159d3fda9d94bc84e011060c8
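As a standalone illustration of the first change, here is a minimal sketch of the zeroing idiom (not FBGEMM code; it assumes asmjit and a caller-provided x86::Emitter, mirroring the initCRegs loops in the diff below):

```
#include <asmjit/asmjit.h>

// Zero a grid of rowRegs x colRegs accumulator registers. A 128-bit
// vpxor is enough: VEX-encoded writes to an xmm register implicitly
// zero the upper bits of the matching ymm/zmm register, and
// xor-with-self is a dependency-breaking zero idiom that can issue on
// multiple ports.
void zeroAccumulators(asmjit::x86::Emitter* a, int rowRegs, int colRegs) {
  using CRegs = asmjit::x86::Xmm;
  for (int i = 0; i < rowRegs; ++i) {
    for (int j = 0; j < colRegs; ++j) {
      a->vpxor(
          CRegs(i * colRegs + j),
          CRegs(i * colRegs + j),
          CRegs(i * colRegs + j));
    }
  }
}
```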
 src/GenerateI8Depthwise.cc                  | 6
 src/GenerateKernelU8S8S32ACC16.cc           | 6
 src/GenerateKernelU8S8S32ACC16Avx512.cc     | 6
 src/GenerateKernelU8S8S32ACC32.cc           | 8
 src/GenerateKernelU8S8S32ACC32Avx512.cc     | 8
 src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 3
 src/GroupwiseConvAcc32Avx2.cc               | 9
 7 files changed, 29 insertions(+), 17 deletions(-)
diff --git a/src/GenerateI8Depthwise.cc b/src/GenerateI8Depthwise.cc
index c63c436..a9509b7 100644
--- a/src/GenerateI8Depthwise.cc
+++ b/src/GenerateI8Depthwise.cc
@@ -351,7 +351,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
}
x86::Ymm zero(vreg_id);
if (need_zero && (!recompute_zero || !has_pad)) {
- e->vxorps(zero, zero, zero);
+ e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
}
// Assign scalar registers
@@ -433,7 +433,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
if (i % 4 == 3 || i == K - 1) {
if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) {
if (recompute_zero && has_pad) {
- e->vxorps(zero, zero, zero);
+ e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
}
}
@@ -465,7 +465,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
asmjit::Imm(r < 2 ? 0x20 : 0x31));
}
for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
- e->vmovaps(c[r], a[r]);
+ e->vmovdqa(c[r], a[r]);
}
}
}
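The vmovaps → vmovdqa swap above is the second change in action: for integer data, an integer-domain move avoids a potential bypass delay when the value is next consumed by an integer instruction. A minimal sketch (hypothetical helper, assuming asmjit):

```
// Register-to-register copy of integer vector data; vmovdqa keeps the
// move in the integer domain (vmovaps is its fp-domain counterpart).
void copyIntVec(asmjit::x86::Emitter* e,
                const asmjit::x86::Ymm& dst,
                const asmjit::x86::Ymm& src) {
  e->vmovdqa(dst, src);
}
```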
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index b05d85a..fd4073a 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) {
- using CRegs = x86::Ymm;
+ using CRegs = x86::Xmm;
+ // Take advantage of implicit zeroing out
+ // i.e., zeroing out xmm zeroes out ymm too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
- a->vxorps(
+ a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index a456fa6..22e7a30 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) {
- using CRegs = x86::Zmm;
+ using CRegs = x86::Xmm;
+ // Take advantage of implicit zeroing out
+ // i.e., zeroing out xmm zeroes out zmm too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
- a->vxorps(
+ a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 4a8e759..0b14aab 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) {
- using CRegs = x86::Ymm;
+ using CRegs = x86::Xmm;
+ // Take advantage of implicit zeroing out
+ // i.e., zeroing out xmm zeroes out ymm too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
- a->vxorps(
+ a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
@@ -61,7 +63,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
for (int j = 0; j < colRegs; ++j) {
// load B
- a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vmovdqa(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
// load A, broadcast and fmas
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
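The same reasoning applies to the aligned B-panel load above. A sketch with assumed names (buffer_B base register, VLEN vector width in int8 elements):

```
// Aligned 256-bit load of one B-panel column in the integer domain.
void loadBPanel(asmjit::x86::Emitter* a,
                const asmjit::x86::Ymm& BReg,
                const asmjit::x86::Gp& buffer_B,
                int j, int VLEN) {
  a->vmovdqa(
      BReg, asmjit::x86::dword_ptr(buffer_B, j * VLEN * sizeof(int8_t)));
}
```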
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index dc1437d..c9dfcc9 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) {
- using CRegs = x86::Zmm;
+ using CRegs = x86::Xmm;
+ // Take advantage of implicit zeroing out
+ // i.e., zeroing out xmm zeroes out zmm too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
- a->vxorps(
+ a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
@@ -60,7 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
// load B
- a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vmovdqa32(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
// load A, broadcast and fmas
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
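Note that the AVX-512 variant uses vmovdqa32 rather than vmovdqa: plain vmovdqa has no 512-bit encoding, so the EVEX form vmovdqa32 (or vmovdqa64) is required for zmm operands. The corresponding sketch, under the same assumed names:

```
// Aligned 512-bit load; EVEX-only vmovdqa32 is the zmm counterpart of
// vmovdqa.
void loadBPanel512(asmjit::x86::Emitter* a,
                   const asmjit::x86::Zmm& BReg,
                   const asmjit::x86::Gp& buffer_B,
                   int j, int VLEN) {
  a->vmovdqa32(
      BReg, asmjit::x86::dword_ptr(buffer_B, j * VLEN * sizeof(int8_t)));
}
```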
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
index 8db53f3..0d9f295 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -45,7 +45,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
assert(colRegs * (rowRegs + 1) <= 31);
for (int j = 0; j < colRegs; ++j) {
- a->vmovaps(x86::Zmm(30-j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vmovdqa32(
+ x86::Zmm(30 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
}
for (int i = 0; i < rowRegs; i++) {
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 95377f0..16ee1c6 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -308,11 +308,13 @@ template <int SPATIAL_DIM>
void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::initResultRegs(
x86::Emitter* a) {
if (kLoopIters_ > 0) {
+ // Take advantage of implicit zeroing out
+ // i.e., zeroing out xmm zeroes out ymm too
for (int k = 0; k < kLoopIters_; ++k) {
- a->vxorps(x86::Ymm(9 - k), x86::Ymm(9 - k), x86::Ymm(9 - k));
+ a->vpxor(x86::Xmm(9 - k), x86::Xmm(9 - k), x86::Xmm(9 - k));
}
} else {
- a->vxorps(x86::Ymm(9), x86::Ymm(9), x86::Ymm(9));
+ a->vpxor(x86::Xmm(9), x86::Xmm(9), x86::Xmm(9));
}
}
@@ -557,7 +559,8 @@ void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::genForSingleOutput(
// row offset
if (this->needRowOffset_) {
- a->vxorps(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
+ a->vpxor(
+ rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
}
bool isWidthMiddle = !isLeft && !isRight;