Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/fbgemm/Fbgemm.h6
-rw-r--r--src/PackWeightsForConv.cc15
-rw-r--r--test/UniConvTest.cc (renamed from test/UniConvPackingTest.cc)63
3 files changed, 81 insertions, 3 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 9ee25b5..302af51 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -597,6 +597,12 @@ class FBGEMM_API PackWeightsForConv {
return W_gconv_packed_;
}
+ /**
+ * @brief Unpack packed matric into origin_buf (Used for the serialization to
+ * recover weight matrix).
+ */
+ void unpack(T* origin_buf);
+
private:
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 78379af..e16843c 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -65,6 +65,21 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
} // switch
}
+template <int SPATIAL_DIM, typename T, typename accT>
+void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
+ if (W_dw_2D_packed_) {
+ W_dw_2D_packed_->unpack(origin_buf);
+ } else if (W_dw_3D_packed_) {
+ W_dw_3D_packed_->unpack(origin_buf);
+ } else if (W_gconv_packed_) {
+ W_gconv_packed_->unpack(origin_buf);
+ } else if (W_im2col_packed_) {
+ W_im2col_packed_->unpack(origin_buf);
+ } else {
+ assert(false && "At least one packed weights object should exist");
+ }
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;
diff --git a/test/UniConvPackingTest.cc b/test/UniConvTest.cc
index 77552af..2b110dd 100644
--- a/test/UniConvPackingTest.cc
+++ b/test/UniConvTest.cc
@@ -23,7 +23,7 @@ using namespace fbgemm;
namespace {
// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
-class convPackingTest
+class uniConvTest
: public testing::TestWithParam<
tuple<int, int, int, int, int, int, int, int, int, int>> {};
@@ -31,7 +31,7 @@ class convPackingTest
INSTANTIATE_TEST_CASE_P(
InstantiationName,
- convPackingTest,
+ uniConvTest,
::testing::Combine(
::testing::ValuesIn({1, 2}), // MB
::testing::ValuesIn({16, 32}), // IC
@@ -47,7 +47,7 @@ INSTANTIATE_TEST_CASE_P(
/**
* Test for conv packing
*/
-TEST_P(convPackingTest, packingTest) {
+TEST_P(uniConvTest, packingTest) {
int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
@@ -146,3 +146,60 @@ TEST_P(convPackingTest, packingTest) {
}
}
}
+
+/**
+ * Test for packing/unpacking
+ */
+TEST_P(uniConvTest, packUnpackTest) {
+ int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
+ tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
+
+ conv_param_t<2> conv_p_2d(
+ MB,
+ IC,
+ OC,
+ {IH, IW},
+ G,
+ {kernel, kernel},
+ {stride, stride},
+ {pad, pad, pad, pad});
+
+ int kernel_dim_2d = kernel * kernel;
+
+ aligned_vector<int8_t> Bint8_2d(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+ aligned_vector<int8_t> Bint8_2d_unpacked(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+
+ PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
+
+ packedB_2D.unpack(Bint8_2d_unpacked.data());
+
+ ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked)
+ << "Original and unpacked data elements are not the same [2D]";
+
+ conv_param_t<3> conv_p_3d(
+ MB,
+ IC,
+ OC,
+ {IT, IH, IW},
+ G,
+ {kernel, kernel, kernel},
+ {stride, stride, stride},
+ {pad, pad, pad, pad, pad, pad});
+
+ int kernel_dim_3d = kernel * kernel * kernel;
+
+ aligned_vector<int8_t> Bint8_3d(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+
+ aligned_vector<int8_t> Bint8_3d_unpacked(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+
+ PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());
+
+ packedB_3D.unpack(Bint8_3d_unpacked.data());
+
+ ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked)
+ << "Original and unpacked data elements are not the same [3D]";
+}