/* * Copyright (c) Facebook, Inc. and its affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include "fbgemm/Fbgemm.h" namespace fbgemm { template PackWeightsForConv::PackWeightsForConv( const conv_param_t& conv_p, const T* sdata, const BlockingFactors* blocking_params) : conv_param_(conv_p) { static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); // Note: The following logic should *exactly* match with what we have in // FbgemmConv.cc switch (ConvFastPath(conv_p)) { case optimized_conv_t::depthwise: { if (SPATIAL_DIM == 3) { W_im2col_packed_ = nullptr; W_dw_2D_packed_ = nullptr; W_dw_3D_packed_ = std::make_shared(conv_p.G, sdata); W_gconv_packed_ = nullptr; } else { W_im2col_packed_ = nullptr; W_dw_2D_packed_ = std::make_shared(conv_p.G, sdata); W_dw_3D_packed_ = nullptr; W_gconv_packed_ = nullptr; } break; } case optimized_conv_t::groupwise: { W_im2col_packed_ = nullptr; W_dw_2D_packed_ = nullptr; W_dw_3D_packed_ = nullptr; W_gconv_packed_ = std::make_shared>( matrix_op_t::Transpose, conv_p, sdata, nullptr); break; } case optimized_conv_t::pointwise: { W_im2col_packed_ = nullptr; W_dw_2D_packed_ = nullptr; W_dw_3D_packed_ = nullptr; W_gconv_packed_ = nullptr; int NDim = conv_p.OC / conv_p.G; int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; W_pointwise_packed_ = std::make_shared>( matrix_op_t::Transpose, KDim, NDim, sdata, KDim / conv_p.G, nullptr, conv_p.G, blocking_params); break; } case optimized_conv_t::im2col: { int NDim = conv_p.OC / conv_p.G; int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; W_im2col_packed_ = std::make_shared>( matrix_op_t::Transpose, KDim, NDim, sdata, KDim / conv_p.G, nullptr, conv_p.G, blocking_params); W_dw_2D_packed_ = nullptr; W_dw_3D_packed_ = nullptr; W_gconv_packed_ = nullptr; break; } } // switch } template void PackWeightsForConv::unpack(T* origin_buf) { if (W_dw_2D_packed_) { W_dw_2D_packed_->unpack(origin_buf); } else if (W_dw_3D_packed_) { W_dw_3D_packed_->unpack(origin_buf); } else if (W_gconv_packed_) { W_gconv_packed_->unpack(origin_buf); } else if (W_im2col_packed_) { W_im2col_packed_->unpack(origin_buf); } else if (W_pointwise_packed_) { W_pointwise_packed_->unpack(origin_buf); } else { assert(false && "At least one packed weights object should exist"); } } template bool PackWeightsForConv::isPackingCompliant( const conv_param_t& test_conv_p) { return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC && conv_param_.G == test_conv_p.G && std::equal( conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin()) && std::equal( conv_param_.stride.begin(), conv_param_.stride.end(), test_conv_p.stride.begin()) && std::equal( conv_param_.pad.begin(), conv_param_.pad.end(), test_conv_p.pad.begin()) && std::equal( conv_param_.dilation.begin(), conv_param_.dilation.end(), test_conv_p.dilation.begin()); } template std::string PackWeightsForConv::mismatchingParams( const conv_param_t& test_conv_p) { std::string msg = ""; auto combineStr = [](std::string id, std::string str1, std::string str2) { std::string out = id + std::string(" "); out += str1; out += std::string(" vs ") + str2; out += std::string(";"); return out; }; auto combineInt = [&combineStr](std::string id, int int1, int int2) { return combineStr(id, std::to_string(int1), std::to_string(int2)); }; if (conv_param_.IC != test_conv_p.IC) { msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC); } if (conv_param_.OC != test_conv_p.OC) { msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC); } if (conv_param_.G != test_conv_p.G) { msg += combineInt("groups", conv_param_.G, test_conv_p.G); } if (!std::equal( conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) { msg += combineStr( "kernel", arrayToString(conv_param_.K), arrayToString(test_conv_p.K)); } if (!std::equal( conv_param_.stride.begin(), conv_param_.stride.end(), test_conv_p.stride.begin())) { msg += combineStr( "stride", arrayToString(conv_param_.stride), arrayToString(test_conv_p.stride)); } if (!std::equal( conv_param_.pad.begin(), conv_param_.pad.end(), test_conv_p.pad.begin())) { msg += combineStr( "pad", arrayToString<2 * SPATIAL_DIM>(conv_param_.pad), arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad)); } if (!std::equal( conv_param_.dilation.begin(), conv_param_.dilation.end(), test_conv_p.dilation.begin())) { msg += combineStr( "dilation", arrayToString(conv_param_.dilation), arrayToString(test_conv_p.dilation)); } return msg; } template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; } // namespace fbgemm