From 1be081503e9b765fc9e18b50b94a9f24bd79025f Mon Sep 17 00:00:00 2001 From: Daya Khudia Date: Mon, 12 Aug 2019 10:42:13 -0700 Subject: fix error message (#117) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/117 Fixes error message with mismatching parameters. Before: ``` [FBGEMM_CONV_ERROR] Prepacked weights can't be used with these convolution parameters! ``` After ``` [FBGEMM_CONV_ERROR] Convolution parameters mismatch between pre-packed weights and conv invocation! stride [1, 1] vs [2, 1]; Please pack weights using the same parameters with which convolution operation is invoked! ``` Reviewed By: jianyuh Differential Revision: D16749007 fbshipit-source-id: 7a3083f2955b798ae28d25ce1963c7de63654551 --- include/fbgemm/Fbgemm.h | 5 ++++ include/fbgemm/Utils.h | 11 ++++++++ src/FbgemmConv.cc | 11 +++++--- src/PackWeightsForConv.cc | 68 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 3 deletions(-) diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 68963fa..543f1cb 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -627,6 +627,11 @@ class FBGEMM_API PackWeightsForConv { */ bool isPackingCompliant(const conv_param_t& conv_p); + /** + * @brief Returns a string of mismatching parameters + */ + std::string mismatchingParams(const conv_param_t& conv_p); + /** * @brief Unpack packed matric into origin_buf (Used for the serialization to * recover weight matrix). diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 3f8522b..0b738c7 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ #pragma once +#include #include #include #include "FbgemmBuild.h" @@ -121,6 +122,16 @@ struct FBGEMM_API BlockingFactors { int NCB; }; +template +FBGEMM_API std::string arrayToString(const std::array& inp) { + std::string out = "["; + for (int i = 0; i < SIZE; ++i) { + out += std::to_string(inp[i]); + out += (i != SIZE - 1) ? std::string(", ") : std::string("]"); + } + return out; +} + template FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { constexpr bool is_32bit = std::is_same::value; diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 027e6c5..33d1535 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -73,9 +73,14 @@ int fbgemmConv( "Only 2D and 3D convolutions are supported"); if (!packed_weights.isPackingCompliant(conv_p)) { - throw std::logic_error( - "[FBGEMM_CONV_ERROR] Prepacked weights can't be used" - " with these convolution parameters!"); + std::string msg = + "[FBGEMM_CONV_ERROR] Convolution parameters " + "mismatch between pre-packed weights and conv invocation! "; + msg += packed_weights.mismatchingParams(conv_p); + msg += std::string( + " Please pack weights using the same parameters " + "with which convolution operation is invoked!"); + throw std::logic_error(msg); } switch (ConvFastPath(conv_p)) { diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 25b04af..44f210e 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -125,6 +125,74 @@ bool PackWeightsForConv::isPackingCompliant( test_conv_p.dilation.begin()); } +template +std::string PackWeightsForConv::mismatchingParams( + const conv_param_t& test_conv_p) { + std::string msg = ""; + + auto combineStr = [](std::string id, std::string str1, std::string str2) { + std::string out = id + std::string(" "); + out += str1; + out += std::string(" vs ") + str2; + out += std::string(";"); + return out; + }; + + auto combineInt = [&combineStr](std::string id, int int1, int int2) { + return combineStr(id, std::to_string(int1), std::to_string(int2)); + }; + + if (conv_param_.IC != test_conv_p.IC) { + msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.OC != test_conv_p.OC) { + msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.G != test_conv_p.G) { + msg += combineInt("groups", conv_param_.G, test_conv_p.G); + } + + if (!std::equal( + conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) { + msg += combineStr( + "kernel", + arrayToString(conv_param_.K), + arrayToString(test_conv_p.K)); + } + + if (!std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin())) { + msg += combineStr( + "stride", + arrayToString(conv_param_.stride), + arrayToString(test_conv_p.stride)); + } + + if (!std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin())) { + msg += combineStr( + "pad", + arrayToString<2 * SPATIAL_DIM>(conv_param_.pad), + arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad)); + } + + if (!std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin())) { + msg += combineStr( + "dilation", + arrayToString(conv_param_.dilation), + arrayToString(test_conv_p.dilation)); + } + + return msg; +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; -- cgit v1.2.3