// Source: cops.h from github.com/marian-nmt/intgemm (mirror hosted at ThFree Co, Russian Federation).
// Path: root/cops.h
// Blob: 196f705cdc99706ca9cd0ade152b9000c6b595ca
#pragma once
#include "intrinsics.h"

#include <cassert>
#include <cstdint>
#include <exception>

namespace intgemm {

/**
 * Functor that takes packed 32-bit integer results, converts them with
 * cvtepi32_ps, multiplies them by a fixed factor with mul_ps, and stores the
 * outcome into the float buffer C.
 *
 * cvtepi32_ps and mul_ps are helpers that come from outside this class
 * (presumably intrinsics.h); their exact semantics are not visible here.
 *
 * One overload of operator() exists per register width (SSE2 and AVX2).
 * Both write through vector-typed pointers derived from C, so C is expected
 * to be suitably aligned; the Init* helpers assert this for the base pointer.
 */
class JustUnquantizeC {
public:
  /**
   * @param C            Output buffer that operator() writes into.
   * @param unquant_mult Factor every converted value is multiplied by.
   */
  JustUnquantizeC(float *C, float unquant_mult);

  /**
   * SSE2 variant: stores result.pack0123 at C[rowIDX*cols + colIDX] and
   * result.pack4567 four floats after it.
   */
  SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result);

  /** AVX2 variant: stores one 256-bit result at C[rowIDX*cols + colIDX]. */
  AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result);

private:
  // Each helper asserts the alignment of C_ for its register width and then
  // broadcasts unquant_mult into the matching member register.
  SSE2 void InitRegisterSSE(float unquant_mult);
  AVX2 void InitRegisterAVX2(float unquant_mult);

  // Destination buffer; stored as a plain pointer and never freed by this class.
  float *C_;

  // Broadcast copies of unquant_mult, one per register width.
  __m128 unquant_mult_128_;
  __m256 unquant_mult_256_;
};

/**
 * Prepare the 128-bit state used by the SSE2 operator().
 *
 * Asserts that the base pointer C_ is a multiple of sizeof(__m128) bytes
 * (the SSE2 operator() writes through a __m128* derived from C_), then
 * broadcasts unquant_mult into every lane of unquant_mult_128_.
 *
 * Declared inline because this definition lives in a header: without it,
 * every translation unit that includes the header would emit its own
 * external definition and the program would fail to link.
 *
 * @param unquant_mult Factor to broadcast.
 */
SSE2 inline void JustUnquantizeC::InitRegisterSSE(float unquant_mult) {
  assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128) == 0);
  unquant_mult_128_ = _mm_set1_ps(unquant_mult);
}

/**
 * Prepare the 256-bit state used by the AVX2 operator().
 *
 * Asserts that the base pointer C_ is a multiple of sizeof(__m256) bytes
 * (the AVX2 operator() writes through a __m256* derived from C_), then
 * broadcasts unquant_mult into every lane of unquant_mult_256_.
 *
 * Declared inline because this definition lives in a header: without it,
 * every translation unit that includes the header would emit its own
 * external definition and the program would fail to link.
 *
 * @param unquant_mult Factor to broadcast.
 */
AVX2 inline void JustUnquantizeC::InitRegisterAVX2(float unquant_mult) {
  assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256) == 0);
  unquant_mult_256_ = _mm256_set1_ps(unquant_mult);
}

/**
 * Store the output pointer and set up the broadcast multiplier registers.
 *
 * The 128-bit register is always initialised. The 256-bit register is only
 * initialised when the running CPU reports AVX2 support; otherwise
 * unquant_mult_256_ is left untouched.
 *
 * Declared inline because this definition lives in a header: without it,
 * every translation unit that includes the header would emit its own
 * external definition and the program would fail to link.
 *
 * @param C            Output buffer later written by operator().
 * @param unquant_mult Factor every converted value is multiplied by.
 */
inline JustUnquantizeC::JustUnquantizeC(float *C, float unquant_mult) : C_(C) {
  // Both register widths are prepared so that the tests pass (per the
  // original author's note).
  // Original note: some of the assertions might give false positives on
  // SSE2/3. NOTE(review): presumably this refers to the 32-byte alignment
  // assert in InitRegisterAVX2 firing on an AVX2-capable machine even when
  // only the SSE2 operator() is used -- confirm.
  InitRegisterSSE(unquant_mult);
  if (__builtin_cpu_supports("avx2")) {
    InitRegisterAVX2(unquant_mult);
  }
}


/**
 * SSE2 store: converts the two packed halves of result with cvtepi32_ps,
 * multiplies each by unquant_mult_128_ and writes them to eight consecutive
 * floats starting at C_[rowIDX*cols + colIDX].
 *
 * The writes go through __m128 pointers, so the target address is expected
 * to be aligned for __m128.
 */
SSE2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result){
  // Shared start address for both 4-float stores.
  float *const dest = C_ + rowIDX*cols + colIDX;

  // Floats 0..3
  const __m128 first_half = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_128_);
  *reinterpret_cast<__m128*>(dest) = first_half;

  // Floats 4..7
  const __m128 second_half = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_128_);
  *reinterpret_cast<__m128*>(dest + 4) = second_half;
}
/**
 * AVX2 store: converts result with cvtepi32_ps, multiplies it by
 * unquant_mult_256_ and writes it to eight consecutive floats starting at
 * C_[rowIDX*cols + colIDX].
 *
 * The write goes through a __m256 pointer, so the target address is
 * expected to be aligned for __m256.
 */
AVX2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
  float *const dest = C_ + rowIDX*cols + colIDX;
  const __m256 scaled = mul_ps(cvtepi32_ps(result), unquant_mult_256_);
  *reinterpret_cast<__m256*>(dest) = scaled;
}
} //Namespace