Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-08-12 20:42:13 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-08-12 20:50:35 +0300
commit1be081503e9b765fc9e18b50b94a9f24bd79025f (patch)
tree0dc133830d4a96954307385b9a1fb0032cb088ba /src
parentaceefe3e0cc59c6754c90d5f5ffe726666b1d0ac (diff)
fix error message (#117)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/117 Fixes error message with mismatching parameters. Before: ``` [FBGEMM_CONV_ERROR] Prepacked weights can't be used with these convolution parameters! ``` After ``` [FBGEMM_CONV_ERROR] Convolution parameters mismatch between pre-packed weights and conv invocation! stride [1, 1] vs [2, 1]; Please pack weights using the same parameters with which convolution operation is invoked! ``` Reviewed By: jianyuh Differential Revision: D16749007 fbshipit-source-id: 7a3083f2955b798ae28d25ce1963c7de63654551
Diffstat (limited to 'src')
-rw-r--r--src/FbgemmConv.cc11
-rw-r--r--src/PackWeightsForConv.cc68
2 files changed, 76 insertions, 3 deletions
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index 027e6c5..33d1535 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -73,9 +73,14 @@ int fbgemmConv(
"Only 2D and 3D convolutions are supported");
if (!packed_weights.isPackingCompliant(conv_p)) {
- throw std::logic_error(
- "[FBGEMM_CONV_ERROR] Prepacked weights can't be used"
- " with these convolution parameters!");
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] Convolution parameters "
+ "mismatch between pre-packed weights and conv invocation! ";
+ msg += packed_weights.mismatchingParams(conv_p);
+ msg += std::string(
+ " Please pack weights using the same parameters "
+ "with which convolution operation is invoked!");
+ throw std::logic_error(msg);
}
switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 25b04af..44f210e 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -125,6 +125,74 @@ bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
test_conv_p.dilation.begin());
}
+template <int SPATIAL_DIM, typename T, typename accT>
+std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ std::string msg = "";
+
+ auto combineStr = [](std::string id, std::string str1, std::string str2) {
+ std::string out = id + std::string(" ");
+ out += str1;
+ out += std::string(" vs ") + str2;
+ out += std::string(";");
+ return out;
+ };
+
+ auto combineInt = [&combineStr](std::string id, int int1, int int2) {
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
+ };
+
+ if (conv_param_.IC != test_conv_p.IC) {
+ msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.OC != test_conv_p.OC) {
+ msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.G != test_conv_p.G) {
+ msg += combineInt("groups", conv_param_.G, test_conv_p.G);
+ }
+
+ if (!std::equal(
+ conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
+ msg += combineStr(
+ "kernel",
+ arrayToString<SPATIAL_DIM>(conv_param_.K),
+ arrayToString<SPATIAL_DIM>(test_conv_p.K));
+ }
+
+ if (!std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin())) {
+ msg += combineStr(
+ "stride",
+ arrayToString<SPATIAL_DIM>(conv_param_.stride),
+ arrayToString<SPATIAL_DIM>(test_conv_p.stride));
+ }
+
+ if (!std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin())) {
+ msg += combineStr(
+ "pad",
+ arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
+ arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
+ }
+
+ if (!std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin())) {
+ msg += combineStr(
+ "dilation",
+ arrayToString<SPATIAL_DIM>(conv_param_.dilation),
+ arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
+ }
+
+ return msg;
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;