/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cpuinfo.h>
#include <iomanip>
#include <stdexcept>
#include <type_traits>
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
namespace fbgemm {
/**
 * @brief Construct the common packing state shared by all packing classes.
 *
 * Records the matrix dimensions, the caller-provided buffer pointer, the
 * group count and the (possibly null) blocking factors. The buffer is marked
 * as not allocated by this object.
 *
 * @param rows   stored as the number of rows.
 * @param cols   stored as the number of columns.
 * @param buf    buffer pointer stored as-is; no allocation happens here.
 * @param groups stored as the group count.
 * @param params blocking factors stored as-is; may be null.
 *
 * @throws std::runtime_error if cpuinfo_initialize() reports failure.
 */
template <typename PT, typename inpType, typename accType>
PackMatrix<PT, inpType, accType>::PackMatrix(
    int32_t rows,
    int32_t cols,
    inpType* buf,
    int groups,
    const BlockingFactors* params)
    : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) {
  // Bail out before touching the remaining members if CPU feature detection
  // cannot be set up.
  const bool cpuinfoReady = cpuinfo_initialize();
  if (!cpuinfoReady) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }

  // Remember the caller's blocking factors (null is stored unchanged).
  blocking_params = params;
  // The buffer was handed in through `buf`, so this object did not allocate it.
  bufAllocatedHere_ = false;
}
/**
 * @brief Compute the size of the buffer needed to hold a packed matrix.
 *
 * The cache-blocking factors (MCB, KCB, NCB) are taken from `params` when it
 * is non-null; otherwise they come from the PackingTraits of the best
 * instruction set detected at run time (AVX512-VNNI, then AVX512, then AVX2).
 *
 * - When isA() is true the result is MCB * KCB; `rows` and `cols` are not
 *   used in that case.
 * - Otherwise `rows` is rounded up to a multiple of KCB, `cols` is rounded up
 *   to a multiple of NCB, and the product of the two is returned.
 *
 * NOTE(review): the unit of the returned size is not visible in this file
 * (presumably a count of inpType elements) -- confirm against the call sites.
 *
 * @param rows   number of rows of the matrix to pack.
 * @param cols   number of columns of the matrix to pack.
 * @param params optional blocking factors; when null, defaults for the
 *               detected instruction set are used.
 * @return the computed buffer size.
 *
 * @throws std::runtime_error if cpuinfo cannot be initialized, or if no
 *         blocking factors were supplied and none of the supported
 *         instruction sets is available.
 */
template <typename PT, typename inpType, typename accType>
int PackMatrix<PT, inpType, accType>::packedBufferSize(
    int rows,
    int cols,
    const BlockingFactors* params) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  // Debug-build tripwire for CPUs without any supported instruction set.
  if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
       !fbgemmHasAvx2Support())) {
    assert(0 && "unknown architecure");
  }

  int MCB = 0;
  int KCB = 0;
  int NCB = 0;
  if (params) {
    // Caller-supplied blocking factors take precedence over the defaults.
    MCB = params->MCB;
    NCB = params->NCB;
    KCB = params->KCB;
  } else if (fbgemmHasAvx512VnniSupport()) {
    MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
    NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
    KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
  } else if (fbgemmHasAvx512Support()) {
    MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
    NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
    KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
  } else if (fbgemmHasAvx2Support()) {
    MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
    NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
    KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
  } else {
    // The assert above is compiled out under NDEBUG; without this branch a
    // release build would silently size the buffer with AVX2 constants on a
    // CPU that supports none of the instruction sets.
    throw std::runtime_error("unknown architecture");
  }

  if (isA()) {
    return MCB * KCB;
  }

  // Round each dimension up to a whole number of blocks.
  const int rowBlock = KCB;
  const int colBlock = NCB;
  const int paddedRows = ((rows + rowBlock - 1) / rowBlock) * rowBlock;
  const int paddedCols = ((cols + colBlock - 1) / colBlock) * colBlock;
  return paddedRows * paddedCols;
}
// int32 accumulation
template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>;
template class PackMatrix<
PackAWithRowOffset<uint8_t, int32_t>,
uint8_t,
int32_t>;
template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
template class PackMatrix<
PackAWithIm2Col<uint8_t, int32_t, 3>,
uint8_t,
int32_t>;
template class PackMatrix<
PackAWithQuantRowOffset<uint8_t, int32_t>,
uint8_t,
int32_t>;
template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
// int16 accumulation
template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
template class PackMatrix<
PackAWithIm2Col<uint8_t, int16_t, 3>,
uint8_t,
int16_t>;
template class PackMatrix<
PackAWithRowOffset<uint8_t, int16_t>,
uint8_t,
int16_t>;
template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>;
template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>;
} // namespace fbgemm