diff options
-rw-r--r-- | include/fbgemm/Fbgemm.h | 6 | ||||
-rw-r--r-- | src/PackWeightsForConv.cc | 15 | ||||
-rw-r--r-- | test/UniConvTest.cc (renamed from test/UniConvPackingTest.cc) | 63 |
3 files changed, 81 insertions, 3 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 9ee25b5..302af51 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -597,6 +597,12 @@ class FBGEMM_API PackWeightsForConv { return W_gconv_packed_; } + /** + * @brief Unpack packed matric into origin_buf (Used for the serialization to + * recover weight matrix). + */ + void unpack(T* origin_buf); + private: // Packed weights if we use im2col based convolution implementation std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_; diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 78379af..e16843c 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -65,6 +65,21 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( } // switch } +template <int SPATIAL_DIM, typename T, typename accT> +void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { + if (W_dw_2D_packed_) { + W_dw_2D_packed_->unpack(origin_buf); + } else if (W_dw_3D_packed_) { + W_dw_3D_packed_->unpack(origin_buf); + } else if (W_gconv_packed_) { + W_gconv_packed_->unpack(origin_buf); + } else if (W_im2col_packed_) { + W_im2col_packed_->unpack(origin_buf); + } else { + assert(false && "At least one packed weights object should exist"); + } +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; diff --git a/test/UniConvPackingTest.cc b/test/UniConvTest.cc index 77552af..2b110dd 100644 --- a/test/UniConvPackingTest.cc +++ b/test/UniConvTest.cc @@ -23,7 +23,7 @@ using namespace fbgemm; namespace { // tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad -class convPackingTest +class uniConvTest : public testing::TestWithParam< tuple<int, int, int, int, int, int, int, int, int, int>> {}; @@ -31,7 +31,7 @@ class convPackingTest INSTANTIATE_TEST_CASE_P( InstantiationName, - convPackingTest, + uniConvTest, ::testing::Combine( ::testing::ValuesIn({1, 2}), // MB ::testing::ValuesIn({16, 32}), // IC @@ -47,7 +47,7 @@ INSTANTIATE_TEST_CASE_P( /** * Test for conv packing */ -TEST_P(convPackingTest, packingTest) { +TEST_P(uniConvTest, packingTest) { int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); @@ -146,3 +146,60 @@ TEST_P(convPackingTest, packingTest) { } } } + +/** + * Test for packing/unpacking + */ +TEST_P(uniConvTest, packUnpackTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector<int8_t> Bint8_2d_unpacked( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + packedB_2D.unpack(Bint8_2d_unpacked.data()); + + ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked) + << "Original and unpacked data elements are not the same [2D]"; + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + aligned_vector<int8_t> Bint8_3d_unpacked( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + packedB_3D.unpack(Bint8_3d_unpacked.data()); + + ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked) + << "Original and unpacked data elements are not the same [3D]"; +} |