diff options
Diffstat (limited to 'src/PackWeightsForConv.cc')
-rw-r--r-- | src/PackWeightsForConv.cc | 33 |
1 files changed, 5 insertions, 28 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 44f210e..192fb00 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -23,35 +23,17 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( // FbgemmConv.cc switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) { case optimized_conv_t::depthwise: { - if (SPATIAL_DIM == 3) { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = - std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata); - W_gconv_packed_ = nullptr; - } else { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = - std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata); - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; - } + W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>( + conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata); break; } case optimized_conv_t::groupwise: { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; W_gconv_packed_ = std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>( matrix_op_t::Transpose, conv_p, sdata, nullptr); break; } case optimized_conv_t::pointwise: { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; int NDim = conv_p.OC / conv_p.G; int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>( @@ -77,9 +59,6 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( nullptr, conv_p.G, blocking_params); - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; break; } } // switch @@ -87,10 +66,8 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( template <int SPATIAL_DIM, typename T, typename accT> void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { - if (W_dw_2D_packed_) { - W_dw_2D_packed_->unpack(origin_buf); - } else if (W_dw_3D_packed_) { - W_dw_3D_packed_->unpack(origin_buf); + if (W_dw_packed_) { + W_dw_packed_->unpack(origin_buf); } else if (W_gconv_packed_) { W_gconv_packed_->unpack(origin_buf); } else if (W_im2col_packed_) { @@ -139,7 +116,7 @@ std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams( }; auto combineInt = [&combineStr](std::string id, int int1, int int2) { - return combineStr(id, std::to_string(int1), std::to_string(int2)); + return combineStr(id, std::to_string(int1), std::to_string(int2)); }; if (conv_param_.IC != test_conv_p.IC) { |