Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2019-02-15 20:43:29 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-15 20:46:23 +0300
commit83b462ec58e78311ba3b64b48cbd49db0485641f (patch)
tree92532e9f70e5c6be5ff3a5e0ab646548f15b7931 /src/FbgemmI8Spmdm.cc
parent05ce78e3a5735217cb9154a2c1572dc956ffe6fc (diff)
simple spmdm optimization (#76)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/76 Create a temporary buffer for accumulating results instead of directly accessing the C matrix with strides. This speeds up the hyper-sparse case, which is implemented without a transpose, so we adjust the threshold between the implementation without transpose and the one with transpose accordingly. Reviewed By: jianyuh Differential Revision: D14097154 fbshipit-source-id: 22e37d0a9f38ccb3d15813edcd96f3d341eacf1c
Diffstat (limited to 'src/FbgemmI8Spmdm.cc')
-rw-r--r--src/FbgemmI8Spmdm.cc78
1 files changed, 59 insertions, 19 deletions
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc
index c249871..10e5a1b 100644
--- a/src/FbgemmI8Spmdm.cc
+++ b/src/FbgemmI8Spmdm.cc
@@ -34,14 +34,14 @@ CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols)
old_nnz_(-1) {}
double CompressedSparseColumn::Density() const {
- return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols());
+ return static_cast<double>(NumOfNonZeros()) / (NumOfRows() * NumOfCols());
}
bool CompressedSparseColumn::IsHyperSparse() const {
if (NumOfNonZeros() != old_nnz_) {
old_nnz_ = NumOfNonZeros();
// The number of non-zero per row is very small.
- hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08;
+ hyper_sparse_ = static_cast<double>(old_nnz_) / NumOfRows() < 0.3;
}
return hyper_sparse_;
@@ -82,25 +82,65 @@ void CompressedSparseColumn::SpMDM(
// The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications.
// If NNZ/K is small, it's not worth doing transpose so we just use this
// scalar loop.
- if (!accumulation) {
- for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
- for (int j = block.col_start; j < block.col_start + block.col_size;
- ++j) {
- C[(i - block.row_start) * ldc + j - block.col_start] = 0;
+ int32_t C_temp[block.row_size];
+ if (accumulation) {
+ for (int j = 0; j < block.col_size; ++j) {
+ int k = colptr_[block.col_start + j];
+ int k_end = colptr_[block.col_start + j + 1];
+ if (k_end == k) {
+ } else if (k_end == k + 1) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = 0; i < block.row_size; ++i) {
+ C[i * ldc + j] += A[(block.row_start + i) * lda + row] * w;
+ }
+ } else {
+ for (int i = 0; i < block.row_size; ++i) {
+ C_temp[i] = C[i * ldc + j];
+ }
+ for (; k < k_end; ++k) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = 0; i < block.row_size; ++i) {
+ C_temp[i] += A[(block.row_start + i) * lda + row] * w;
+ }
+ }
+ for (int i = 0; i < block.row_size; ++i) {
+ C[i * ldc + j] = C_temp[i];
+ }
}
- }
- }
- for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
- for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
- int row = rowidx_[k];
- int w = values_[k];
- for (int i = block.row_start; i < block.row_start + block.row_size;
- ++i) {
- C[(i - block.row_start) * ldc + j - block.col_start] +=
- A[i * lda + row] * w;
+ } // for each column of B
+ } else {
+ for (int j = 0; j < block.col_size; ++j) {
+ int k = colptr_[block.col_start + j];
+ int k_end = colptr_[block.col_start + j + 1];
+ if (k_end == k) {
+ for (int i = 0; i < block.row_size; ++i) {
+ C[i * ldc + j] = 0;
+ }
+ } else if (k_end == k + 1) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = 0; i < block.row_size; ++i) {
+ C[i * ldc + j] = A[(block.row_start + i) * lda + row] * w;
+ }
+ } else {
+ for (int i = 0; i < block.row_size; ++i) {
+ C_temp[i] = 0;
+ }
+ for (; k < k_end; ++k) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = 0; i < block.row_size; ++i) {
+ C_temp[i] += A[(block.row_start + i) * lda + row] * w;
+ }
+ }
+ for (int i = 0; i < block.row_size; ++i) {
+ C[i * ldc + j] = C_temp[i];
+ }
}
- }
- } // for each column of B
+ } // for each column of B
+ }
return;
}