From 8392eca198742b949529e18619a7ec9a25f4b399 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Thu, 15 Nov 2018 18:21:15 -0800 Subject: grouped (batched) gemm (#7) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/7 This diff allows groups > 1. A separate diff will follow for im2col + gemm fusion and conv with group > 1. Reviewed By: jianyuh Differential Revision: D13039210 fbshipit-source-id: f7b3b0dbdb67fc6bc865de88292f034b252d029d --- src/ExecuteKernelU8S8.cc | 55 +++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 22 deletions(-) (limited to 'src/ExecuteKernelU8S8.cc') diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index b3f8c15..c2079b1 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -8,7 +8,6 @@ #include #include - #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN double kernel_time = 0.0; double postprocessing_time = 0.0; @@ -84,8 +83,10 @@ void ExecuteKernel< int32_t packed_rows_A = packedA_.numPackedRows(); int32_t row_start_A = packedA_.packedRowStart(); - bool lastKBlock = packedB_.isThisLastKBlock(kBlock); - bool accum = kBlock > 0; + int group = kBlock / packedB_.blockRows(); + int NDim = packedB_.numCols(); + bool lastKBlock = packedB_.isThisLastKBlock(kBlock % packedB_.blockRows()); + bool accum = (kBlock % packedB_.blockRows()) > 0; typename BaseType::jit_micro_kernel_fp fn; @@ -120,7 +121,6 @@ void ExecuteKernel< #endif for (int jb = 0; jb < bColBlocks; ++jb) { - bBuf = packedB_.getBuf(jb, kBlock); // prefetch addr of the next packed block of B matrix bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock); @@ -128,12 +128,14 @@ void ExecuteKernel< // Reuse the first rowblock of C_buffer_ unless when C_buffer_ is same as // matC_ (inplace output processing) int32_t* C_buffer_row_start = C_buffer_ + - ((C_buffer_ == reinterpret_cast(matC_)) ? row_start_A * ldc_ - : 0); + ((C_buffer_ == reinterpret_cast(matC_)) + ? 
row_start_A * ldc_ + NDim * group + : 0); int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_; int32_t leadingDim = ldc_; if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) { - // In case we will access memory past C_buffer_, we use C_tile_ instead. + // In case we will access memory past C_buffer_, we use C_tile_ scratchpad + // instead. C_buffer_start = C_tile_; leadingDim = nbSize_; } @@ -146,14 +148,15 @@ void ExecuteKernel< leadingDim); #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN - t_end = std::chrono::high_resolution_clock::now(); - dt = std::chrono::duration_cast(t_end - t_start) - .count(); - kernel_time += (dt); - t_start = std::chrono::high_resolution_clock::now(); + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + kernel_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); #endif - // Output processing is done only once per rowblock + // Output processing is done only once per rowblock to amortize overhead + // and for better spatial locality. if (lastKBlock && jb == bColBlocks - 1) { // When C_tile_ is used for the last column block, we need a separate // handling for the last column block. @@ -166,14 +169,14 @@ void ExecuteKernel< outputProcess_.template f( matC_, C_buffer_row_start, - {row_start_A, packed_rows_A, 0, nSize}, + {row_start_A, packed_rows_A, NDim * group, nSize}, ldc_, ldc_); } else if (cpuinfo_has_x86_avx2()) { outputProcess_.template f( matC_, C_buffer_row_start, - {row_start_A, packed_rows_A, 0, nSize}, + {row_start_A, packed_rows_A, NDim * group, nSize}, ldc_, ldc_); } else { @@ -183,20 +186,28 @@ void ExecuteKernel< } if (C_buffer_start == C_tile_) { + // When C_tile_ scratchpad was used to avoid accessing memory past + // C_buffer_ . 
if (cpuinfo_has_x86_avx512f()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f( matC_, C_tile_, - {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + {row_start_A, + packed_rows_A, + NDim * group + jb * nbSize_, + packedB_.lastBcol()}, ldc_, leadingDim); } else if (cpuinfo_has_x86_avx2()) { outputProcess_.template f( matC_, C_tile_, - {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + {row_start_A, + packed_rows_A, + NDim * group + jb * nbSize_, + packedB_.lastBcol()}, ldc_, leadingDim); } else { @@ -207,11 +218,11 @@ void ExecuteKernel< } // output processing #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN - t_end = std::chrono::high_resolution_clock::now(); - dt = std::chrono::duration_cast(t_end - t_start) - .count(); - postprocessing_time += (dt); - t_start = std::chrono::high_resolution_clock::now(); + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast(t_end - t_start) + .count(); + postprocessing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); #endif } // for each j block -- cgit v1.2.3