From 6e903d5f7fc9f0bdc716767409d0b994d1de6cf6 Mon Sep 17 00:00:00 2001 From: Daya Khudia Date: Tue, 16 Jul 2019 19:24:07 -0700 Subject: While calling fbgemmConv with packed weights, packed weights should be compliant with convolution parameters Summary: This is to detect inadvertent calling for fbgemmConv with one set of conv parameters while packing was done with another set of parameters. Reviewed By: jspark1105 Differential Revision: D16269293 fbshipit-source-id: 9a166f5298d8246047e40fc880dd87e1037e0456 --- include/fbgemm/Fbgemm.h | 6 ++++ src/FbgemmConv.cc | 7 ++++ src/PackWeightsForConv.cc | 24 ++++++++++++++ test/UniConvTest.cc | 84 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+) diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 9cee28f..bb9f2ca 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -613,6 +613,12 @@ class FBGEMM_API PackWeightsForConv { return conv_param_.G; } + /** + * @brief Returns true if the packed weights would work for the given + * convolution parameters, and false otherwise + */ + bool isPackingCompliant(const conv_param_t& conv_p); + /** * @brief Unpack packed matric into origin_buf (Used for the serialization to * recover weight matrix). diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 5db63f6..be54693 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -58,6 +58,13 @@ int fbgemmConv( static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); + + if (!packed_weights.isPackingCompliant(conv_p)) { + throw std::logic_error( + "[FBGEMM_CONV_ERROR] Prepacked weights can't be used" + " with these convolution parameters!"); + } + switch (ConvFastPath(conv_p)) { case optimized_conv_t::depthwise: { // 2D and 3D depthwise fast path diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 085adc0..ccb97fd 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -4,6 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include #include #include "fbgemm/Fbgemm.h" @@ -81,6 +82,29 @@ void PackWeightsForConv::unpack(T* origin_buf) { } } +template +bool PackWeightsForConv::isPackingCompliant( + const conv_param_t& test_conv_p) { + return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC && + conv_param_.G == test_conv_p.G && + std::equal( + conv_param_.K.begin(), + conv_param_.K.end(), + test_conv_p.K.begin()) && + std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin()) && + std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin()) && + std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin()); +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc index 2b110dd..893afcb 100644 --- a/test/UniConvTest.cc +++ b/test/UniConvTest.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -203,3 +204,86 @@ TEST_P(uniConvTest, packUnpackTest) { ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked) << "Original and unpacked data elements are not the same [3D]"; } + +TEST(uniConvTest, cornerCases) { + int stride = 1; + conv_param_t<2> conv_p_2d( + 1, // mini-batch + 16, // input channels + 32, // output channels + {28, 28}, // input height/width + 4, // groups + {3, 3}, // kernel height/width + {stride, stride}, // strides + {1, 1, 1, 1}); // padding + + int kernel_dim_2d = conv_p_2d.K[0] * conv_p_2d.K[1]; + + aligned_vector Aint8( + conv_p_2d.MB * conv_p_2d.IN_DIM[0] * conv_p_2d.IN_DIM[1] * conv_p_2d.IC); + aligned_vector Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector Cint32_fb( + conv_p_2d.MB * conv_p_2d.OUT_DIM[0] * conv_p_2d.OUT_DIM[1] * + conv_p_2d.OC); + aligned_vector Cint8_fb(Cint32_fb.size(), 0); + + // A matrix (input activations) + randFill(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + // B matrix (weights) + randFill(Bint8_2d, -4, 4); + aligned_vector Bint8_zero_point(1); + randFill(Bint8_zero_point, -3, -1); + + aligned_vector C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_point = 5; + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + vector col_offsets(conv_p_2d.OC); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput outputProcObj( + doNothingObj, + C_multiplier.data(), + C_zero_point, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, // row offsets + col_offsets.data(), + nullptr, // bias + conv_p_2d.OC, + conv_p_2d.G); + + try { + conv_p_2d.stride[0] = 2; + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); + } catch (std::logic_error const& err) { + std::string s(err.what()); + EXPECT_TRUE(s.rfind("[FBGEMM_CONV_ERROR]", 0) == 0); + } + + // reset + conv_p_2d.stride[0] = stride; + // this should run fine + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); +} -- cgit v1.2.3