diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-06-04 06:31:51 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-06-04 06:36:23 +0300 |
commit | 77868418c7963572167690ef069b06cbfe67de1f (patch) | |
tree | 59745ce716e94d6095b678edb4f6b2a4fdf7dfea | |
parent | 85f4105cebfe538905dd167db1e2355f9637d8a1 (diff) |
Add quantized::fbgemm_linear_unpack operator for serialization (#97)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/97
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20721
- FBGEMM: Add unpack function for PackBMatrix class: Unpack pmat buffer to the origin_buf (Used for the serialization to recover weight matrix).
- PyTorch Quantizer: Add quantized::fbgemm_linear_unpack operator for serialization.
Reviewed By: zafartahirov
Differential Revision: D15314568
fbshipit-source-id: 12080c8887ce31dc849d23e132ae1766ac319407
-rw-r--r-- | include/fbgemm/Fbgemm.h | 13 | ||||
-rw-r--r-- | src/PackBMatrix.cc | 55 |
2 files changed, 66 insertions, 2 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 48f7255..720f681 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -216,6 +216,13 @@ class PackMatrix { } /** + * @return The first column of the block we're working on. + */ + std::int32_t packedColStart() const { + return packedBlock_.col_start; + } + + /** * @return The beginning of (rowBlockNum, colBlockNum)th block */ inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) { @@ -451,6 +458,12 @@ class FBGEMM_API PackBMatrix final */ bool equals(const PackBMatrix<T, accT>& that) const; + /** + * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to + * recover weight matrix). + */ + void unpack(T* origin_buf); + ~PackBMatrix() {} private: diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 970a741..b6d06ca 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -241,9 +241,9 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { BaseType::packedBlock(block); bool tr = (trans_ == matrix_op_t::Transpose); - for (int g = 0; g < this->numGroups(); ++g) { + for (int g = 0; g < BaseType::numGroups(); ++g) { T* out = BaseType::getBuf() + - g * this->packedBufferSize(block.row_size, block.col_size); + g * BaseType::packedBufferSize(block.row_size, block.col_size); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * (BaseType::blockRowSize() * BaseType::blockColSize()) + @@ -304,6 +304,57 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { } template <typename T, typename accT> +void PackBMatrix<T, accT>::unpack(T* origin_buf) { + bool tr = (trans_ == matrix_op_t::Transpose); + for (int g = 0; g < this->numGroups(); ++g) { + T* out = BaseType::getBuf() + + g * + BaseType::packedBufferSize( + BaseType::numPackedRows(), BaseType::numPackedCols()); + for (int i = BaseType::packedRowStart(); + i < BaseType::packedRowStart() + BaseType::numPackedRows(); + ++i) { + int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()) + + (i % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + i % row_interleave_; + + int c_start_offset = + (BaseType::packedColStart() / BaseType::blockColSize()) * + BaseType::blockRowSize() * BaseType::blockColSize() + + (BaseType::packedColStart() % BaseType::blockColSize()) * + row_interleave_; + + int c_idx_offset = 0; + int c_blk_offset = 0; + for (int j = BaseType::packedColStart(); + j < BaseType::packedColStart() + BaseType::numPackedCols(); + ++j) { + int c_offset = c_start_offset + + c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + + c_idx_offset * row_interleave_; + + int out_idx = r_offset + c_offset; + + T val = out[out_idx]; + if (tr) { + origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val; + } else { + origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val; + } + + c_idx_offset++; + if (c_idx_offset == BaseType::blockColSize()) { + c_idx_offset = 0; + c_blk_offset++; + } + } + } + } // for each group +} + +template <typename T, typename accT> int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { int32_t block_row_id = r / BaseType::blockRowSize(); int32_t brow_offset = (block_row_id * BaseType::blockCols()) * |