Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYoung Jin Kim <youki@microsoft.com>2020-04-05 12:30:57 +0300
committerYoung Jin Kim <youki@microsoft.com>2020-04-05 12:30:57 +0300
commit383591912206b67950fe31248a4110de8ab06c8d (patch)
treefffe2af30c83871ac9607bf826c3ce834c41eeac
parentb842b37bcfcc42bae5098349e08a7174eb178f4e (diff)
Add some more sparse experimentsyouki/testsparse
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc24
1 files changed, 13 insertions, 11 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 6141cc6..6ac51e8 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -63,26 +63,28 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
x86::Zmm res1 = x86::zmm28;
//std::cout << "sparse: " << sparse << std::endl;
- x86::Gp zero = a->zax();
- //x86::Gp bloaded = a->zax();
+ //x86::Gp zero = a->zax();
+ x86::Gp bloaded = a->zax();
using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
if(sparse) {
// load B
- a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
- //a->mov(bloaded, 0);
+ //a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->mov(bloaded, 0);
for(int i = 0; i < rowRegs; ++i) {
asmjit::Label SkipFma = a->newLabel();
- //asmjit::Label SkipBload = a->newLabel();
- a->mov(zero, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
- a->test(zero, zero);
+ asmjit::Label SkipBload = a->newLabel();
+ //a->mov(zero, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ //a->test(zero, zero);
+ a->cmp(x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)), 0);
a->je(SkipFma);
- //a->cmp(bloaded, 0);
- //a->jne(SkipBload);
- //a->inc(bloaded);
- //a->bind(SkipBload);
+ a->cmp(bloaded, 0);
+ a->jne(SkipBload);
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->inc(bloaded);
+ a->bind(SkipBload);
a->vpbroadcastd(AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
a->vpmaddubsw(res1, AReg, BReg);