blob: aa9f0ff9a75c9299b872f9c1cfd16f7acf1fb2d8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
#pragma once
#include "intrinsics.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <exception>
namespace intgemm {
/**
 * Output stage that converts integer multiply results to float, scales them
 * by a single multiplier and writes them into the float matrix C.
 *
 * This outer class only carries the configuration (destination pointer and
 * scalar multiplier). The nested OnSSE2 / OnAVX2 functors are built from it
 * and hold the multiplier already broadcast into a vector register.
 *
 * C is addressed as C[rowIDX * cols + colIDX]. The stores go through vector
 * pointers, so C must be aligned to the vector width (asserted at
 * construction); the offsets passed to operator() are presumably multiples
 * of the vector width as well -- that is not checked here.
 */
class JustUnquantizeC {
  public:
    /**
     * @param C            destination float buffer (not owned).
     * @param unquant_mult factor applied to every converted result.
     */
    JustUnquantizeC(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}

    /** 128-bit variant: consumes eight 32-bit results per call as two packs of four. */
    class OnSSE2 {
      public:
        /** Broadcasts the multiplier and checks that C is 16-byte aligned. */
        SSE2 explicit OnSSE2(const JustUnquantizeC &from)
          : C_(from.C_), unquant_mult_(_mm_set1_ps(from.unquant_mult_)) {
          assert(reinterpret_cast<std::uintptr_t>(C_) % sizeof(__m128i) == 0);
        }

        /**
         * Writes eight consecutive floats starting at C[rowIDX * cols + colIDX].
         *
         * @param rowIDX row of the first output element.
         * @param cols   row stride of C, in elements.
         * @param colIDX column of the first output element.
         * @param result integer results; pack0123 goes to the first four
         *               floats, pack4567 to the next four.
         */
        SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
          // Widen before multiplying: if Index is narrower than a pointer,
          // rowIDX * cols would wrap for very large outputs and the store
          // would land at the wrong address.
          float *out = C_ + static_cast<std::size_t>(rowIDX) * cols + colIDX;
          *reinterpret_cast<__m128*>(out) = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_);
          *reinterpret_cast<__m128*>(out + 4) = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_);
        }

      private:
        float *C_;
        __m128 unquant_mult_;
    };

    /** 256-bit variant: consumes eight 32-bit results per call in one register. */
    class OnAVX2 {
      public:
        /** Broadcasts the multiplier and checks that C is 32-byte aligned. */
        AVX2 explicit OnAVX2(const JustUnquantizeC &from)
          : C_(from.C_), unquant_mult_(_mm256_set1_ps(from.unquant_mult_)) {
          assert(reinterpret_cast<std::uintptr_t>(C_) % sizeof(__m256i) == 0);
        }

        /**
         * Writes eight consecutive floats starting at C[rowIDX * cols + colIDX].
         *
         * @param rowIDX row of the first output element.
         * @param cols   row stride of C, in elements.
         * @param colIDX column of the first output element.
         * @param result eight 32-bit integer results.
         */
        AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
          // Same widening as the SSE2 path to avoid wrap-around in the offset.
          float *out = C_ + static_cast<std::size_t>(rowIDX) * cols + colIDX;
          *reinterpret_cast<__m256*>(out) = mul_ps(cvtepi32_ps(result), unquant_mult_);
        }

      private:
        float *C_;
        __m256 unquant_mult_;
    };

  private:
    float *C_;
    float unquant_mult_;
};
} //Namespace
|