diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-07-17 05:24:07 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-07-17 05:29:18 +0300 |
commit | 6e903d5f7fc9f0bdc716767409d0b994d1de6cf6 (patch) | |
tree | 2178e4b281eb6f91daa17a48642b51f80350cd00 /src | |
parent | 931b3b71f021401dede9f35f3920eeb1e98c4c09 (diff) |
While calling fbgemmConv with packed weights, packed weights should be compliant with convolution parameters
Summary: This is to detect inadvertent calling for fbgemmConv with one set of conv parameters while packing was done with another set of parameters.
Reviewed By: jspark1105
Differential Revision: D16269293
fbshipit-source-id: 9a166f5298d8246047e40fc880dd87e1037e0456
Diffstat (limited to 'src')
-rw-r--r-- | src/FbgemmConv.cc | 7 | ||||
-rw-r--r-- | src/PackWeightsForConv.cc | 24 |
2 files changed, 31 insertions, 0 deletions
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 5db63f6..be54693 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -58,6 +58,13 @@ int fbgemmConv( static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); + + if (!packed_weights.isPackingCompliant(conv_p)) { + throw std::logic_error( + "[FBGEMM_CONV_ERROR] Prepacked weights can't be used" + " with these convolution parameters!"); + } + switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { case optimized_conv_t::depthwise: { // 2D and 3D depthwise fast path diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 085adc0..ccb97fd 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -4,6 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include <algorithm> #include <memory> #include "fbgemm/Fbgemm.h" @@ -81,6 +82,29 @@ void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { } } +template <int SPATIAL_DIM, typename T, typename accT> +bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant( + const conv_param_t<SPATIAL_DIM>& test_conv_p) { + return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC && + conv_param_.G == test_conv_p.G && + std::equal( + conv_param_.K.begin(), + conv_param_.K.end(), + test_conv_p.K.begin()) && + std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin()) && + std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin()) && + std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin()); +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; |