diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-07-10 06:59:38 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-07-10 07:30:41 +0300 |
commit | f08039388abf2fc9908b5086a8c884202355e649 (patch) | |
tree | ec511e374bb30cb08247465e7fa6897a219bd0c4 | |
parent | 815139b1ba72ed43571084f58c7fdf4ca3c991d5 (diff) |
Refactoring unpack weight function (#103)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/103
In the same spirit as D16085552, we do the following in this Diff:
- Refactor the pack/unpack code for PackB: use the same ```pack_unpack_``` function for both the ```pack``` and ```unpack``` functions.
- Add a unit test.
Reviewed By: dskhudia
Differential Revision: D16160767
fbshipit-source-id: 7fb7006750537b0705a180f2014c786298a1c615
-rw-r--r-- | include/fbgemm/Fbgemm.h | 9 | ||||
-rw-r--r-- | src/PackBMatrix.cc | 117 | ||||
-rw-r--r-- | test/PackedRequantizeAcc16Test.cc | 72 | ||||
-rw-r--r-- | test/PackedRequantizeTest.cc | 70 |
4 files changed, 197 insertions, 71 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 87f0907..9ee25b5 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -472,6 +472,15 @@ class FBGEMM_API PackBMatrix final const T* smat_; std::int32_t ld_; std::int32_t row_interleave_; + + /** + * @brief Internal function performing both pack & unpack + */ + void pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack); }; /** diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 76198ca..497cc2d 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -234,7 +234,11 @@ PackBMatrix<T, accT>::PackBMatrix( } template <typename T, typename accT> -void PackBMatrix<T, accT>::pack(const block_type_t& block) { +void PackBMatrix<T, accT>::pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack) { assert((BaseType::blockRowSize() % row_interleave_) == 0); assert((block.row_start % BaseType::blockRowSize()) == 0); assert((block.col_start % BaseType::blockColSize()) == 0); @@ -242,7 +246,7 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { BaseType::packedBlock(block); bool tr = (trans_ == matrix_op_t::Transpose); for (int g = 0; g < BaseType::numGroups(); ++g) { - T* out = BaseType::getBuf() + + T* pack_buf_cur = pack_buf + g * BaseType::packedBufferSize(block.row_size, block.col_size); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * @@ -268,10 +272,16 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + c_idx_offset * row_interleave_; - int out_idx = r_offset + c_offset; - T val = tr ? smat_[i + (g * block.col_size + j) * ld_] - : smat_[(g * block.row_size + i) * ld_ + j]; - out[out_idx] = val; + if (ispack) { + pack_buf_cur[r_offset + c_offset] = tr + ? 
unpack_buf[i + (g * block.col_size + j) * ld_] + : unpack_buf[(g * block.row_size + i) * ld_ + j]; + } else { + T* unpack_buf_cur = tr + ? &(unpack_buf[i + (g * block.col_size + j) * ld_]) + : &(unpack_buf[(g * block.row_size + i) * ld_ + j]); + *unpack_buf_cur = pack_buf_cur[r_offset + c_offset]; + } c_idx_offset++; if (c_idx_offset == BaseType::blockColSize()) { @@ -280,78 +290,45 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { } } } - // fill the remaining with zero. - // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. - for (int i = block.row_start + block.row_size; - i < (block.row_start + block.row_size + row_interleave_ - 1) / - row_interleave_ * row_interleave_; - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - for (int j = block.col_start; j < block.col_start + block.col_size; j++) { - int c_offset = (j / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (j % BaseType::blockColSize()) * row_interleave_; + if (ispack) { + // fill the remaining with zero. + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. 
+ for (int i = block.row_start + block.row_size; + i < (block.row_start + block.row_size + row_interleave_ - 1) / + row_interleave_ * row_interleave_; + ++i) { + int r_offset = + ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()) + + (i % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + i % row_interleave_; + for (int j = block.col_start; j < block.col_start + block.col_size; + j++) { + int c_offset = (j / BaseType::blockColSize()) * + BaseType::blockRowSize() * BaseType::blockColSize() + + (j % BaseType::blockColSize()) * row_interleave_; - int out_idx = r_offset + c_offset; - out[out_idx] = 0; + int out_idx = r_offset + c_offset; + pack_buf_cur[out_idx] = 0; + } } } } // for each group } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf) { - bool tr = (trans_ == matrix_op_t::Transpose); - for (int g = 0; g < this->numGroups(); ++g) { - T* out = BaseType::getBuf() + - g * - BaseType::packedBufferSize( - BaseType::numPackedRows(), BaseType::numPackedCols()); - for (int i = BaseType::packedRowStart(); - i < BaseType::packedRowStart() + BaseType::numPackedRows(); - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - - int c_start_offset = - (BaseType::packedColStart() / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (BaseType::packedColStart() % BaseType::blockColSize()) * - row_interleave_; - - int c_idx_offset = 0; - int c_blk_offset = 0; - for (int j = BaseType::packedColStart(); - j < BaseType::packedColStart() + BaseType::numPackedCols(); - ++j) { - int c_offset = c_start_offset + - c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + - c_idx_offset * 
row_interleave_; - - int out_idx = r_offset + c_offset; - - T val = out[out_idx]; - if (tr) { - origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val; - } else { - origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val; - } +void PackBMatrix<T, accT>::pack(const block_type_t& block) { + pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true); +} - c_idx_offset++; - if (c_idx_offset == BaseType::blockColSize()) { - c_idx_offset = 0; - c_blk_offset++; - } - } - } - } // for each group +template <typename T, typename accT> +void PackBMatrix<T, accT>::unpack(T* origin_buf) { + block_type_t blockB{BaseType::packedRowStart(), + BaseType::numPackedRows(), + BaseType::packedColStart(), + BaseType::numPackedCols()}; + pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false); } template <typename T, typename accT> diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index 55f6e7f..40254cb 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -26,7 +26,7 @@ using namespace std; using namespace fbgemm; vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose, - matrix_op_t::Transpose}; + matrix_op_t::Transpose}; vector<QuantizationGranularity> qGranularityVals{ QuantizationGranularity::TENSOR, @@ -39,6 +39,8 @@ class fbgemmu8s8acc16WithQuantGranularityTest tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {}; class fbgemmu8s8acc16Test : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {}; +class fbgemmPackUnpackAcc16Test + : public testing::TestWithParam<tuple<matrix_op_t, bool>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -58,6 +60,11 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(transposeVals), ::testing::Bool())); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmPackUnpackAcc16Test, + ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool())); + /** * @brief Shapes for unit test. 
*/ @@ -809,3 +816,66 @@ TEST_P(fbgemmu8s8acc16Test, NoRequantizeTest) { } // for each groups } // for each shape } + +/** + * @brief Unit test for packing and unpacking the weight tensor. + */ +TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) { + vector<vector<int>> shapes(GetShapes_()); + matrix_op_t btrans; + bool test_ld; + tie(btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + for (int groups : {1, 3, 4}) { + int n = shape[1]; + int k = shape[2]; + + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; + + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); + + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // Note that packing for weight is performed during the constructor + // stage. + PackBMatrix<int8_t, int16_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? 
k_per_group : n, + nullptr, + groups); + + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); + + // Perform unpacking + packedWeights.unpack(unpack_buf.data()); + + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + << "Pack/Unpack results differ at index (" << i << ", " << j + << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) + << ", Pack-Unpacked: " + << static_cast<int>(unpack_buf.data()[i * n + j]); + } + } + } + } +} diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc index fd827b0..11ef6ff 100644 --- a/test/PackedRequantizeTest.cc +++ b/test/PackedRequantizeTest.cc @@ -39,6 +39,8 @@ class fbgemmu8s8acc32WithQuantGranularityTest tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {}; class fbgemmu8s8acc32Test : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {}; +class fbgemmPackUnpackAcc32Test + : public testing::TestWithParam<tuple<matrix_op_t, bool>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -58,6 +60,11 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(transposeVals), ::testing::Bool())); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmPackUnpackAcc32Test, + ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool())); + /** * @brief Shapes for unit test. */ @@ -749,3 +756,66 @@ TEST_P(fbgemmu8s8acc32Test, TestSymmetricQuantizedInputOutput) { } // for each groups } // for each shape } + +/** + * @brief Unit test for packing and unpacking the weight tensor. 
+ */ +TEST_P(fbgemmPackUnpackAcc32Test, TestPackUnpack) { + vector<vector<int>> shapes(GetShapes_()); + matrix_op_t btrans; + bool test_ld; + tie(btrans, test_ld) = GetParam(); + + for (auto shape : shapes) { + for (int groups : {1, 3, 4}) { + int n = shape[1]; + int k = shape[2]; + + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; + + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); + + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // Note that packing for weight is performed during the constructor + // stage. + PackBMatrix<int8_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k_per_group : n, + nullptr, + groups); + + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); + + // Perform unpacking + packedWeights.unpack(unpack_buf.data()); + + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + << "Pack/Unpack results differ at index (" << i << ", " << j + << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) + << ", Pack-Unpacked: " + << static_cast<int>(unpack_buf.data()[i * n + j]); + } + } + } + } +} |