diff options
author | Daya Khudia <dskhudia@fb.com> | 2020-05-14 20:59:21 +0300 |
---|---|---|
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2020-05-14 21:01:18 +0300 |
commit | 46981b81867571cadfbbdadb019ad306eac29f51 (patch) | |
tree | 6831171d34584556a530fe59dcc1870aef3bace1 | |
parent | 7ed5f9f16cffd1c17eae6b6d5f5f07f06cc83565 (diff) |
Minor improvements in GEMM Kernels (#368)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/368
1) Replace vxorps with vpxor. vpxor can execute on multiple ports. Also, zeroing out an xmm register implicitly zeroes out the corresponding ymm/zmm register, so only the xmm form is needed.
2) Use the integer version of vmov (on some architectures there is a penalty for crossing between floating-point and integer vector execution domains).
Reference: https://stackoverflow.com/questions/33666617/what-is-the-best-way-to-set-a-register-to-zero-in-x86-assembly-xor-mov-or-and
No significant change in performance.
Before:
```
M, N, K, Type, GOPS
64, 800, 320, MKL_fp32, 46.3
64, 800, 320, FBGEMM_i8_acc32, 83.6
64, 800, 320, FBGEMM_i8_acc16, 79.2
64, 768, 512, MKL_fp32, 46.9
64, 768, 512, FBGEMM_i8_acc32, 88.1
64, 768, 512, FBGEMM_i8_acc16, 89.2
16, 256, 512, MKL_fp32, 27.6
16, 256, 512, FBGEMM_i8_acc32, 43.5
16, 256, 512, FBGEMM_i8_acc16, 54.1
128, 128, 128, MKL_fp32, 30.1
128, 128, 128, FBGEMM_i8_acc32, 42.1
128, 128, 128, FBGEMM_i8_acc16, 40.6
256, 512, 256, MKL_fp32, 44.8
256, 512, 256, FBGEMM_i8_acc32, 91.7
256, 512, 256, FBGEMM_i8_acc16, 91.1
1024, 1024, 1024, MKL_fp32, 48.8
1024, 1024, 1024, FBGEMM_i8_acc32, 97.0
1024, 1024, 1024, FBGEMM_i8_acc16, 97.6
```
After:
```
M, N, K, Type, GOPS
64, 800, 320, MKL_fp32, 46.2
64, 800, 320, FBGEMM_i8_acc32, 83.5
64, 800, 320, FBGEMM_i8_acc16, 80.8
64, 768, 512, MKL_fp32, 47.2
64, 768, 512, FBGEMM_i8_acc32, 88.5
64, 768, 512, FBGEMM_i8_acc16, 87.3
16, 256, 512, MKL_fp32, 26.0
16, 256, 512, FBGEMM_i8_acc32, 44.0
16, 256, 512, FBGEMM_i8_acc16, 54.5
128, 128, 128, MKL_fp32, 29.6
128, 128, 128, FBGEMM_i8_acc32, 42.2
128, 128, 128, FBGEMM_i8_acc16, 38.5
256, 512, 256, MKL_fp32, 44.3
256, 512, 256, FBGEMM_i8_acc32, 91.1
256, 512, 256, FBGEMM_i8_acc16, 91.0
1024, 1024, 1024, MKL_fp32, 48.7
1024, 1024, 1024, FBGEMM_i8_acc32, 96.6
1024, 1024, 1024, FBGEMM_i8_acc16, 96.5
```
Reviewed By: jspark1105
Differential Revision: D21433384
fbshipit-source-id: d0abd56f454293e159d3fda9d94bc84e011060c8
-rw-r--r-- | src/GenerateI8Depthwise.cc | 6 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16.cc | 6 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 6 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32.cc | 8 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 8 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 3 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 9 |
7 files changed, 29 insertions, 17 deletions
diff --git a/src/GenerateI8Depthwise.cc b/src/GenerateI8Depthwise.cc index c63c436..a9509b7 100644 --- a/src/GenerateI8Depthwise.cc +++ b/src/GenerateI8Depthwise.cc @@ -351,7 +351,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } x86::Ymm zero(vreg_id); if (need_zero && (!recompute_zero || !has_pad)) { - e->vxorps(zero, zero, zero); + e->vpxor(zero.xmm(), zero.xmm(), zero.xmm()); } // Assign scalar registers @@ -433,7 +433,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( if (i % 4 == 3 || i == K - 1) { if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) { if (recompute_zero && has_pad) { - e->vxorps(zero, zero, zero); + e->vpxor(zero.xmm(), zero.xmm(), zero.xmm()); } } @@ -465,7 +465,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( asmjit::Imm(r < 2 ? 0x20 : 0x31)); } for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) { - e->vmovaps(c[r], a[r]); + e->vmovdqa(c[r], a[r]); } } } diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index b05d85a..fd4073a 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -19,10 +19,12 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) { - using CRegs = x86::Ymm; + using CRegs = x86::Xmm; + // Take advantage of implicit zeroing out + // i.e., zero out xmm and ymm will be zeroed out too for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { - a->vxorps( + a->vpxor( CRegs(i * colRegs + j), CRegs(i * colRegs + j), CRegs(i * colRegs + j)); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index a456fa6..22e7a30 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,10 +19,12 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< 
inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) { - using CRegs = x86::Zmm; + using CRegs = x86::Xmm; + // Take advantage of implicit zeroing out + // i.e., zero out xmm and zmm will be zeroed out too for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { - a->vxorps( + a->vpxor( CRegs(i * colRegs + j), CRegs(i * colRegs + j), CRegs(i * colRegs + j)); diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 4a8e759..0b14aab 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -19,10 +19,12 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) { - using CRegs = x86::Ymm; + using CRegs = x86::Xmm; + // Take advantage of implicit zeroing out + // i.e., zero out xmm and ymm will be zeroed out too for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { - a->vxorps( + a->vpxor( CRegs(i * colRegs + j), CRegs(i * colRegs + j), CRegs(i * colRegs + j)); @@ -61,7 +63,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< for (int j = 0; j < colRegs; ++j) { // load B - a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vmovdqa(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); // load A, broadcast and fmas for (int i = 0; i < rowRegs; ++i) { a->vpbroadcastd( diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index dc1437d..c9dfcc9 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -19,10 +19,12 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) { - using CRegs = x86::Zmm; + using CRegs = x86::Xmm; + // Take advantage of implicit zeroing out + // i.e., zero out xmm and zmm will be zeroed out too for (int i = 0; i < 
rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { - a->vxorps( + a->vpxor( CRegs(i * colRegs + j), CRegs(i * colRegs + j), CRegs(i * colRegs + j)); @@ -60,7 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< using CRegs = x86::Zmm; for (int j = 0; j < colRegs; ++j) { // load B - a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vmovdqa32(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); // load A, broadcast and fmas for (int i = 0; i < rowRegs; ++i) { a->vpbroadcastd( diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc index 8db53f3..0d9f295 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -45,7 +45,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< assert(colRegs * (rowRegs + 1) <= 31); for (int j = 0; j < colRegs; ++j) { - a->vmovaps(x86::Zmm(30-j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vmovdqa32( + x86::Zmm(30 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); } for (int i = 0; i < rowRegs; i++) { diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index 95377f0..16ee1c6 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -308,11 +308,13 @@ template <int SPATIAL_DIM> void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::initResultRegs( x86::Emitter* a) { if (kLoopIters_ > 0) { + // Take advantage of implicit zeroing out + // i.e., zero out xmm and ymm will be zeroed out too for (int k = 0; k < kLoopIters_; ++k) { - a->vxorps(x86::Ymm(9 - k), x86::Ymm(9 - k), x86::Ymm(9 - k)); + a->vpxor(x86::Xmm(9 - k), x86::Xmm(9 - k), x86::Xmm(9 - k)); } } else { - a->vxorps(x86::Ymm(9), x86::Ymm(9), x86::Ymm(9)); + a->vpxor(x86::Xmm(9), x86::Xmm(9), x86::Xmm(9)); } } @@ -557,7 +559,8 @@ void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::genForSingleOutput( // row offset if (this->needRowOffset_) 
{ - a->vxorps(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_); + a->vpxor( + rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm()); } bool isWidthMiddle = !isLeft && !isRight; |