diff options
Diffstat (limited to 'src/PackWeightsForConv.cc')
-rw-r--r-- | src/PackWeightsForConv.cc | 160 |
1 files changed, 137 insertions, 23 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index c811144..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -4,6 +4,7 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <algorithm>
 #include <memory>
 #include "fbgemm/Fbgemm.h"
 
@@ -13,7 +14,8 @@ template <int SPATIAL_DIM, typename T, typename accT>
 PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
     const conv_param_t<SPATIAL_DIM>& conv_p,
     const T* sdata,
-    const BlockingFactors* blocking_params) {
+    const BlockingFactors* blocking_params)
+    : conv_param_(conv_p) {
   static_assert(
       SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
       "Only 2D and 3D convolutions are supported");
@@ -21,50 +23,162 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
   // FbgemmConv.cc
   switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
     case optimized_conv_t::depthwise: {
-      if (SPATIAL_DIM == 3) {
-        W_im2col_packed_ = nullptr;
-        W_dw_2D_packed_ = nullptr;
-        W_dw_3D_packed_ =
-            std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
-        W_gconv_packed_ = nullptr;
-      } else {
-        W_im2col_packed_ = nullptr;
-        W_dw_2D_packed_ =
-            std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
-        W_dw_3D_packed_ = nullptr;
-        W_gconv_packed_ = nullptr;
-      }
+      W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+          conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
       break;
     }
     case optimized_conv_t::groupwise: {
-      W_im2col_packed_ = nullptr;
-      W_dw_2D_packed_ = nullptr;
-      W_dw_3D_packed_ = nullptr;
       W_gconv_packed_ =
           std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
-              matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
+              matrix_op_t::Transpose, conv_p, sdata, nullptr);
+      break;
+    }
+    case optimized_conv_t::pointwise: {
+      int NDim = conv_p.OC / conv_p.G;
+      int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+      W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
+          matrix_op_t::Transpose,
+          KDim,
+          NDim,
+          sdata,
+          KDim / conv_p.G,
+          nullptr,
+          conv_p.G,
+          blocking_params);
       break;
     }
     case optimized_conv_t::im2col: {
       int NDim = conv_p.OC / conv_p.G;
       int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
       W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
-          matrix_op_t::NoTranspose,
+          matrix_op_t::Transpose,
           KDim,
           NDim,
           sdata,
-          NDim,
+          KDim / conv_p.G,
           nullptr,
           conv_p.G,
           blocking_params);
-      W_dw_2D_packed_ = nullptr;
-      W_dw_3D_packed_ = nullptr;
-      W_gconv_packed_ = nullptr;
       break;
     }
   } // switch
 }
 
+// Forwards origin_buf to the unpack() method of whichever packed-weight member
+// is non-null, probing them in the order depthwise, groupwise, im2col,
+// pointwise. Asserts if none of them is set.
+template <int SPATIAL_DIM, typename T, typename accT>
+void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
+  if (W_dw_packed_) {
+    W_dw_packed_->unpack(origin_buf);
+  } else if (W_gconv_packed_) {
+    W_gconv_packed_->unpack(origin_buf);
+  } else if (W_im2col_packed_) {
+    W_im2col_packed_->unpack(origin_buf);
+  } else if (W_pointwise_packed_) {
+    W_pointwise_packed_->unpack(origin_buf);
+  } else {
+    assert(false && "At least one packed weights object should exist");
+  }
+}
+
+// Returns true when test_conv_p matches the convolution parameters captured at
+// construction time in IC, OC, G, K, stride, pad and dilation.
+template <int SPATIAL_DIM, typename T, typename accT>
+bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
+    const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+  return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC &&
+      conv_param_.G == test_conv_p.G &&
+      std::equal(
+          conv_param_.K.begin(),
+          conv_param_.K.end(),
+          test_conv_p.K.begin()) &&
+      std::equal(
+          conv_param_.stride.begin(),
+          conv_param_.stride.end(),
+          test_conv_p.stride.begin()) &&
+      std::equal(
+          conv_param_.pad.begin(),
+          conv_param_.pad.end(),
+          test_conv_p.pad.begin()) &&
+      std::equal(
+          conv_param_.dilation.begin(),
+          conv_param_.dilation.end(),
+          test_conv_p.dilation.begin());
+}
+
+// Describes how test_conv_p differs from the convolution parameters captured
+// at construction time. Every differing field among IC, OC, G, K, stride, pad
+// and dilation appends one "<name> <stored> vs <test>;" entry; the returned
+// string is empty when none of them differ.
+template <int SPATIAL_DIM, typename T, typename accT>
+std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
+    const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+  std::string msg = "";
+
+  auto combineStr = [](std::string id, std::string str1, std::string str2) {
+    std::string out = id + std::string(" ");
+    out += str1;
+    out += std::string(" vs ") + str2;
+    out += std::string(";");
+    return out;
+  };
+
+  auto combineInt = [&combineStr](std::string id, int int1, int int2) {
+    return combineStr(id, std::to_string(int1), std::to_string(int2));
+  };
+
+  if (conv_param_.IC != test_conv_p.IC) {
+    msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
+  }
+  if (conv_param_.OC != test_conv_p.OC) {
+    msg += combineInt("output_channels", conv_param_.OC, test_conv_p.OC);
+  }
+  if (conv_param_.G != test_conv_p.G) {
+    msg += combineInt("groups", conv_param_.G, test_conv_p.G);
+  }
+
+  if (!std::equal(
+          conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
+    msg += combineStr(
+        "kernel",
+        arrayToString<SPATIAL_DIM>(conv_param_.K),
+        arrayToString<SPATIAL_DIM>(test_conv_p.K));
+  }
+
+  if (!std::equal(
+          conv_param_.stride.begin(),
+          conv_param_.stride.end(),
+          test_conv_p.stride.begin())) {
+    msg += combineStr(
+        "stride",
+        arrayToString<SPATIAL_DIM>(conv_param_.stride),
+        arrayToString<SPATIAL_DIM>(test_conv_p.stride));
+  }
+
+  if (!std::equal(
+          conv_param_.pad.begin(),
+          conv_param_.pad.end(),
+          test_conv_p.pad.begin())) {
+    msg += combineStr(
+        "pad",
+        arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
+        arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
+  }
+
+  if (!std::equal(
+          conv_param_.dilation.begin(),
+          conv_param_.dilation.end(),
+          test_conv_p.dilation.begin())) {
+    msg += combineStr(
+        "dilation",
+        arrayToString<SPATIAL_DIM>(conv_param_.dilation),
+        arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
+  }
+
+  return msg;
+}
+
 template class PackWeightsForConv<2, int8_t, int32_t>;
 template class PackWeightsForConv<3, int8_t, int32_t>;
 