From 83b462ec58e78311ba3b64b48cbd49db0485641f Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Fri, 15 Feb 2019 09:43:29 -0800 Subject: simple spmdm optimization (#76) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/76 Create a temp buffer for accumulating results instead of directly accessing C matrix with strides. This speeds up hyper-sparse case implemented w/o transpose so we adjust the threshold between the implementation w/o transpose and w/ transpose accordingly. Reviewed By: jianyuh Differential Revision: D14097154 fbshipit-source-id: 22e37d0a9f38ccb3d15813edcd96f3d341eacf1c --- src/FbgemmI8Spmdm.cc | 78 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 19 deletions(-) (limited to 'src/FbgemmI8Spmdm.cc') diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc index c249871..10e5a1b 100644 --- a/src/FbgemmI8Spmdm.cc +++ b/src/FbgemmI8Spmdm.cc @@ -34,14 +34,14 @@ CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols) old_nnz_(-1) {} double CompressedSparseColumn::Density() const { - return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols()); + return static_cast<double>(NumOfNonZeros()) / (NumOfRows() * NumOfCols()); } bool CompressedSparseColumn::IsHyperSparse() const { if (NumOfNonZeros() != old_nnz_) { old_nnz_ = NumOfNonZeros(); // The number of non-zero per row is very small. - hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08; + hyper_sparse_ = static_cast<double>(old_nnz_) / NumOfRows() < 0.3; } return hyper_sparse_; @@ -82,25 +82,65 @@ void CompressedSparseColumn::SpMDM( // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. // If NNZ/K is small, it's not worth doing transpose so we just use this // scalar loop. 
- if (!accumulation) { - for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { - for (int j = block.col_start; j < block.col_start + block.col_size; - ++j) { - C[(i - block.row_start) * ldc + j - block.col_start] = 0; + int32_t C_temp[block.row_size]; + if (accumulation) { + for (int j = 0; j < block.col_size; ++j) { + int k = colptr_[block.col_start + j]; + int k_end = colptr_[block.col_start + j + 1]; + if (k_end == k) { + } else if (k_end == k + 1) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] += A[(block.row_start + i) * lda + row] * w; + } + } else { + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] = C[i * ldc + j]; + } + for (; k < k_end; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] += A[(block.row_start + i) * lda + row] * w; + } + } + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = C_temp[i]; + } } - } - } - for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { - for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) { - int row = rowidx_[k]; - int w = values_[k]; - for (int i = block.row_start; i < block.row_start + block.row_size; - ++i) { - C[(i - block.row_start) * ldc + j - block.col_start] += - A[i * lda + row] * w; + } // for each column of B + } else { + for (int j = 0; j < block.col_size; ++j) { + int k = colptr_[block.col_start + j]; + int k_end = colptr_[block.col_start + j + 1]; + if (k_end == k) { + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = 0; + } + } else if (k_end == k + 1) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = A[(block.row_start + i) * lda + row] * w; + } + } else { + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] = 0; + } + for (; k < k_end; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = 0; i < block.row_size; ++i) { + C_temp[i] += 
A[(block.row_start + i) * lda + row] * w; + } + } + for (int i = 0; i < block.row_size; ++i) { + C[i * ldc + j] = C_temp[i]; + } } - } - } // for each column of B + } // for each column of B + } return; } -- cgit v1.2.3