diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-02-15 20:43:29 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-15 20:46:23 +0300 |
commit | 83b462ec58e78311ba3b64b48cbd49db0485641f (patch) | |
tree | 92532e9f70e5c6be5ff3a5e0ab646548f15b7931 /src/FbgemmI8Spmdm.cc | |
parent | 05ce78e3a5735217cb9154a2c1572dc956ffe6fc (diff) |
simple spmdm optimization (#76)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/76
Create a temporary buffer for accumulating results instead of accessing the C matrix directly with strided reads and writes.
This speeds up the hyper-sparse case, which is implemented without a transpose, so we adjust the threshold between the implementation without transpose and the implementation with transpose accordingly.
Reviewed By: jianyuh
Differential Revision: D14097154
fbshipit-source-id: 22e37d0a9f38ccb3d15813edcd96f3d341eacf1c
Diffstat (limited to 'src/FbgemmI8Spmdm.cc')
-rw-r--r-- | src/FbgemmI8Spmdm.cc | 78 |
1 file changed, 59 insertions, 19 deletions
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc index c249871..10e5a1b 100644 --- a/src/FbgemmI8Spmdm.cc +++ b/src/FbgemmI8Spmdm.cc @@ -34,14 +34,14 @@ CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols) old_nnz_(-1) {} double CompressedSparseColumn::Density() const { - return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols()); + return static_cast<double>(NumOfNonZeros()) / (NumOfRows() * NumOfCols()); } bool CompressedSparseColumn::IsHyperSparse() const { if (NumOfNonZeros() != old_nnz_) { old_nnz_ = NumOfNonZeros(); // The number of non-zero per row is very small. - hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08; + hyper_sparse_ = static_cast<double>(old_nnz_) / NumOfRows() < 0.3; } return hyper_sparse_; @@ -82,25 +82,65 @@ void CompressedSparseColumn::SpMDM( // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. // If NNZ/K is small, it's not worth doing transpose so we just use this // scalar loop. - if (!accumulation) { - for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { - for (int j = block.col_start; j < block.col_start + block.col_size; - ++j) { - C[(i - block.row_start) * ldc + j - block.col_start] = 0; + int32_t C_temp[block.row_size]; + if (accumulation) { + for (int j = 0; j < block.col_size; ++j) { + int k = colptr_[block.col_start + j]; + int k_end = colptr_[block.col_start + j + 1]; + if (k_end == k) { + } else if (k_end == k + 1) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] += A[(block.row_start + i) * lda + row] * w; + } + } else { + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] = C[i * ldc + j]; + } + for (; k < k_end; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] += A[(block.row_start + i) * lda + row] * w; + } + } + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = C_temp[i]; + } } - } - } - for (int j = 
block.col_start; j < block.col_start + block.col_size; ++j) { - for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) { - int row = rowidx_[k]; - int w = values_[k]; - for (int i = block.row_start; i < block.row_start + block.row_size; - ++i) { - C[(i - block.row_start) * ldc + j - block.col_start] += - A[i * lda + row] * w; + } // for each column of B + } else { + for (int j = 0; j < block.col_size; ++j) { + int k = colptr_[block.col_start + j]; + int k_end = colptr_[block.col_start + j + 1]; + if (k_end == k) { + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = 0; + } + } else if (k_end == k + 1) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = A[(block.row_start + i) * lda + row] * w; + } + } else { + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] = 0; + } + for (; k < k_end; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] += A[(block.row_start + i) * lda + row] * w; + } + } + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = C_temp[i]; + } } - } - } // for each column of B + } // for each column of B + } return; } |