Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackWeightsForConv.cc')
-rw-r--r--src/PackWeightsForConv.cc151
1 files changed, 128 insertions, 23 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index c811144..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -4,6 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
+#include <algorithm>
#include <memory>
#include "fbgemm/Fbgemm.h"
@@ -13,7 +14,8 @@ template <int SPATIAL_DIM, typename T, typename accT>
PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
const conv_param_t<SPATIAL_DIM>& conv_p,
const T* sdata,
- const BlockingFactors* blocking_params) {
+ const BlockingFactors* blocking_params)
+ : conv_param_(conv_p) {
static_assert(
SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported");
@@ -21,50 +23,153 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
- if (SPATIAL_DIM == 3) {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ =
- std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
- W_gconv_packed_ = nullptr;
- } else {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ =
- std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
- }
+ W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+ conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
break;
}
case optimized_conv_t::groupwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
- matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
+ matrix_op_t::Transpose, conv_p, sdata, nullptr);
+ break;
+ }
+ case optimized_conv_t::pointwise: {
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+ W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
+ matrix_op_t::Transpose,
+ KDim,
+ NDim,
+ sdata,
+ KDim / conv_p.G,
+ nullptr,
+ conv_p.G,
+ blocking_params);
break;
}
case optimized_conv_t::im2col: {
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
- matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose,
KDim,
NDim,
sdata,
- NDim,
+ KDim / conv_p.G,
nullptr,
conv_p.G,
blocking_params);
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
break;
}
} // switch
}
+template <int SPATIAL_DIM, typename T, typename accT>
+void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
+ if (W_dw_packed_) {
+ W_dw_packed_->unpack(origin_buf);
+ } else if (W_gconv_packed_) {
+ W_gconv_packed_->unpack(origin_buf);
+ } else if (W_im2col_packed_) {
+ W_im2col_packed_->unpack(origin_buf);
+ } else if (W_pointwise_packed_) {
+ W_pointwise_packed_->unpack(origin_buf);
+ } else {
+ assert(false && "At least one packed weights object should exist");
+ }
+}
+
+template <int SPATIAL_DIM, typename T, typename accT>
+bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC &&
+ conv_param_.G == test_conv_p.G &&
+ std::equal(
+ conv_param_.K.begin(),
+ conv_param_.K.end(),
+ test_conv_p.K.begin()) &&
+ std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin()) &&
+ std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin()) &&
+ std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin());
+}
+
+template <int SPATIAL_DIM, typename T, typename accT>
+std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ std::string msg = "";
+
+ auto combineStr = [](std::string id, std::string str1, std::string str2) {
+ std::string out = id + std::string(" ");
+ out += str1;
+ out += std::string(" vs ") + str2;
+ out += std::string(";");
+ return out;
+ };
+
+ auto combineInt = [&combineStr](std::string id, int int1, int int2) {
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
+ };
+
+ if (conv_param_.IC != test_conv_p.IC) {
+ msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.OC != test_conv_p.OC) {
+ msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.G != test_conv_p.G) {
+ msg += combineInt("groups", conv_param_.G, test_conv_p.G);
+ }
+
+ if (!std::equal(
+ conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
+ msg += combineStr(
+ "kernel",
+ arrayToString<SPATIAL_DIM>(conv_param_.K),
+ arrayToString<SPATIAL_DIM>(test_conv_p.K));
+ }
+
+ if (!std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin())) {
+ msg += combineStr(
+ "stride",
+ arrayToString<SPATIAL_DIM>(conv_param_.stride),
+ arrayToString<SPATIAL_DIM>(test_conv_p.stride));
+ }
+
+ if (!std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin())) {
+ msg += combineStr(
+ "pad",
+ arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
+ arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
+ }
+
+ if (!std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin())) {
+ msg += combineStr(
+ "dilation",
+ arrayToString<SPATIAL_DIM>(conv_param_.dilation),
+ arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
+ }
+
+ return msg;
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;