Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianyu Huang <jianyuhuang@fb.com>2019-06-04 06:31:51 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-06-04 06:36:23 +0300
commit77868418c7963572167690ef069b06cbfe67de1f (patch)
tree59745ce716e94d6095b678edb4f6b2a4fdf7dfea
parent85f4105cebfe538905dd167db1e2355f9637d8a1 (diff)
Add quantized::fbgemm_linear_unpack operator for serialization (#97)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/97 Pull Request resolved: https://github.com/pytorch/pytorch/pull/20721 - FBGEMM: Add unpack function for PackBMatrix class: Unpack pmat buffer to the origin_buf (Used for the serialization to recover weight matrix). - PyTorch Quantizer: Add quantized::fbgemm_linear_unpack operator for serialization. Reviewed By: zafartahirov Differential Revision: D15314568 fbshipit-source-id: 12080c8887ce31dc849d23e132ae1766ac319407
-rw-r--r--include/fbgemm/Fbgemm.h13
-rw-r--r--src/PackBMatrix.cc55
2 files changed, 66 insertions, 2 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 48f7255..720f681 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -216,6 +216,13 @@ class PackMatrix {
}
/**
+ * @return The first column of the block we're working on.
+ */
+ std::int32_t packedColStart() const {
+ return packedBlock_.col_start;
+ }
+
+ /**
* @return The beginning of (rowBlockNum, colBlockNum)th block
*/
inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) {
@@ -451,6 +458,12 @@ class FBGEMM_API PackBMatrix final
*/
bool equals(const PackBMatrix<T, accT>& that) const;
+ /**
+ * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to
+ * recover weight matrix).
+ */
+ void unpack(T* origin_buf);
+
~PackBMatrix() {}
private:
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 970a741..b6d06ca 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -241,9 +241,9 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
BaseType::packedBlock(block);
bool tr = (trans_ == matrix_op_t::Transpose);
- for (int g = 0; g < this->numGroups(); ++g) {
+ for (int g = 0; g < BaseType::numGroups(); ++g) {
T* out = BaseType::getBuf() +
- g * this->packedBufferSize(block.row_size, block.col_size);
+ g * BaseType::packedBufferSize(block.row_size, block.col_size);
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
(BaseType::blockRowSize() * BaseType::blockColSize()) +
@@ -304,6 +304,57 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
}
template <typename T, typename accT>
+void PackBMatrix<T, accT>::unpack(T* origin_buf) {
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ for (int g = 0; g < this->numGroups(); ++g) {
+ T* out = BaseType::getBuf() +
+ g *
+ BaseType::packedBufferSize(
+ BaseType::numPackedRows(), BaseType::numPackedCols());
+ for (int i = BaseType::packedRowStart();
+ i < BaseType::packedRowStart() + BaseType::numPackedRows();
+ ++i) {
+ int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize()) +
+ (i % BaseType::blockRowSize() / row_interleave_) *
+ BaseType::blockColSize() * row_interleave_ +
+ i % row_interleave_;
+
+ int c_start_offset =
+ (BaseType::packedColStart() / BaseType::blockColSize()) *
+ BaseType::blockRowSize() * BaseType::blockColSize() +
+ (BaseType::packedColStart() % BaseType::blockColSize()) *
+ row_interleave_;
+
+ int c_idx_offset = 0;
+ int c_blk_offset = 0;
+ for (int j = BaseType::packedColStart();
+ j < BaseType::packedColStart() + BaseType::numPackedCols();
+ ++j) {
+ int c_offset = c_start_offset +
+ c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() +
+ c_idx_offset * row_interleave_;
+
+ int out_idx = r_offset + c_offset;
+
+ T val = out[out_idx];
+ if (tr) {
+ origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val;
+ } else {
+ origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val;
+ }
+
+ c_idx_offset++;
+ if (c_idx_offset == BaseType::blockColSize()) {
+ c_idx_offset = 0;
+ c_blk_offset++;
+ }
+ }
+ }
+ } // for each group
+}
+
+template <typename T, typename accT>
int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
int32_t block_row_id = r / BaseType::blockRowSize();
int32_t brow_offset = (block_row_id * BaseType::blockCols()) *