1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <memory>
#include "fbgemm/Fbgemm.h"
namespace fbgemm {

/**
 * Pre-packs convolution weights for the fast path that
 * ConvFastPath<SPATIAL_DIM, accT>() selects for `conv_p`.
 *
 * For each handled path, exactly one of W_dw_2D_packed_, W_dw_3D_packed_,
 * W_gconv_packed_ or W_im2col_packed_ is allocated and the other three are
 * set to nullptr.
 *
 * @param conv_p          Convolution parameters (groups, channels, kernel).
 * @param sdata           Unpacked source weights; the expected layout is the
 *                        one consumed by the selected packing routine.
 * @param blocking_params Optional blocking factors, forwarded only to the
 *                        im2col (PackBMatrix) path.
 */
template <int SPATIAL_DIM, typename T, typename accT>
PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const T* sdata,
    const BlockingFactors* blocking_params) {
  static_assert(
      SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
      "Only 2D and 3D convolutions are supported");
  // Note: The following logic should *exactly* match with what we have in
  // FbgemmConv.cc
  switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
    case optimized_conv_t::depthwise: {
      W_im2col_packed_ = nullptr;
      W_gconv_packed_ = nullptr;
      if (SPATIAL_DIM == 3) {
        W_dw_2D_packed_ = nullptr;
        W_dw_3D_packed_ =
            std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
      } else {
        W_dw_2D_packed_ =
            std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
        W_dw_3D_packed_ = nullptr;
      }
      break;
    }
    case optimized_conv_t::groupwise: {
      W_im2col_packed_ = nullptr;
      W_dw_2D_packed_ = nullptr;
      W_dw_3D_packed_ = nullptr;
      W_gconv_packed_ =
          std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
              matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
      break;
    }
    case optimized_conv_t::im2col: {
      // Output channels handled per group.
      int NDim = conv_p.OC / conv_p.G;
      // The GEMM reduction dimension is the product of *all* kernel spatial
      // extents times IC. Multiplying only K[0] * K[1] would drop K[2] and
      // under-size KDim for 3D convolutions routed to im2col.
      // NOTE(review): IC is the total input-channel count, and
      // PackBMatrix presumably splits KDim across conv_p.G groups. Confirm
      // this against the PackBMatrix docs.
      int KDim = conv_p.IC;
      for (int d = 0; d < SPATIAL_DIM; ++d) {
        KDim *= conv_p.K[d];
      }
      W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
          matrix_op_t::NoTranspose,
          KDim,
          NDim,
          sdata,
          NDim,
          nullptr,
          conv_p.G,
          blocking_params);
      W_dw_2D_packed_ = nullptr;
      W_dw_3D_packed_ = nullptr;
      W_gconv_packed_ = nullptr;
      break;
    }
  } // switch
}

// Explicit instantiations: int8 weights with int32 accumulation, for 2D and
// 3D convolutions. No other combinations are instantiated in this file.
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;

} // namespace fbgemm
|