Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackWeightsForConv.cc')
-rw-r--r--src/PackWeightsForConv.cc33
1 files changed, 5 insertions, 28 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 44f210e..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -23,35 +23,17 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
- if (SPATIAL_DIM == 3) {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ =
- std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
- W_gconv_packed_ = nullptr;
- } else {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ =
- std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
- }
+ W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+ conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
break;
}
case optimized_conv_t::groupwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
matrix_op_t::Transpose, conv_p, sdata, nullptr);
break;
}
case optimized_conv_t::pointwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
@@ -77,9 +59,6 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
nullptr,
conv_p.G,
blocking_params);
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
break;
}
} // switch
@@ -87,10 +66,8 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
template <int SPATIAL_DIM, typename T, typename accT>
void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
- if (W_dw_2D_packed_) {
- W_dw_2D_packed_->unpack(origin_buf);
- } else if (W_dw_3D_packed_) {
- W_dw_3D_packed_->unpack(origin_buf);
+ if (W_dw_packed_) {
+ W_dw_packed_->unpack(origin_buf);
} else if (W_gconv_packed_) {
W_gconv_packed_->unpack(origin_buf);
} else if (W_im2col_packed_) {
@@ -139,7 +116,7 @@ std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
};
auto combineInt = [&combineStr](std::string id, int int1, int int2) {
- return combineStr(id, std::to_string(int1), std::to_string(int2));
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
};
if (conv_param_.IC != test_conv_p.IC) {