Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2018-11-16 05:21:15 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-16 05:24:41 +0300
commit8392eca198742b949529e18619a7ec9a25f4b399 (patch)
treeda6f827a09b6a90d6294ca566f073ff387225a3b /src/ExecuteKernelU8S8.cc
parentaa5b56a0de8b8f19047a9d797bf2d70cca4bc1f0 (diff)
grouped (batched) gemm (#7)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/7 This diff allows groups > 1 . Will have a separate diff for im2col + gemm fusion and conv with group > 1 . Reviewed By: jianyuh Differential Revision: D13039210 fbshipit-source-id: f7b3b0dbdb67fc6bc865de88292f034b252d029d
Diffstat (limited to 'src/ExecuteKernelU8S8.cc')
-rw-r--r--src/ExecuteKernelU8S8.cc55
1 files changed, 33 insertions, 22 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index b3f8c15..c2079b1 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -8,7 +8,6 @@
#include <cpuinfo.h>
#include <chrono>
-
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
double kernel_time = 0.0;
double postprocessing_time = 0.0;
@@ -84,8 +83,10 @@ void ExecuteKernel<
int32_t packed_rows_A = packedA_.numPackedRows();
int32_t row_start_A = packedA_.packedRowStart();
- bool lastKBlock = packedB_.isThisLastKBlock(kBlock);
- bool accum = kBlock > 0;
+ int group = kBlock / packedB_.blockRows();
+ int NDim = packedB_.numCols();
+ bool lastKBlock = packedB_.isThisLastKBlock(kBlock % packedB_.blockRows());
+ bool accum = (kBlock % packedB_.blockRows()) > 0;
typename BaseType::jit_micro_kernel_fp fn;
@@ -120,7 +121,6 @@ void ExecuteKernel<
#endif
for (int jb = 0; jb < bColBlocks; ++jb) {
-
bBuf = packedB_.getBuf(jb, kBlock);
// prefetch addr of the next packed block of B matrix
bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);
@@ -128,12 +128,14 @@ void ExecuteKernel<
// Reuse the first rowblock of C_buffer_ unless when C_buffer_ is same as
// matC_ (inplace output processing)
int32_t* C_buffer_row_start = C_buffer_ +
- ((C_buffer_ == reinterpret_cast<int32_t*>(matC_)) ? row_start_A * ldc_
- : 0);
+ ((C_buffer_ == reinterpret_cast<int32_t*>(matC_))
+ ? row_start_A * ldc_ + NDim * group
+ : 0);
int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_;
int32_t leadingDim = ldc_;
if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) {
- // In case we will access memory past C_buffer_, we use C_tile_ instead.
+ // In case we will access memory past C_buffer_, we use C_tile_ scratchpad
+ // instead.
C_buffer_start = C_tile_;
leadingDim = nbSize_;
}
@@ -146,14 +148,15 @@ void ExecuteKernel<
leadingDim);
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
- t_end = std::chrono::high_resolution_clock::now();
- dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
- .count();
- kernel_time += (dt);
- t_start = std::chrono::high_resolution_clock::now();
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ kernel_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
#endif
- // Output processing is done only once per rowblock
+ // Output processing is done only once per rowblock to amortize overhead
+ // and for better spatial locality.
if (lastKBlock && jb == bColBlocks - 1) {
// When C_tile_ is used for the last column block, we need a separate
// handling for the last column block.
@@ -166,14 +169,14 @@ void ExecuteKernel<
outputProcess_.template f<inst_set_t::avx2>(
matC_,
C_buffer_row_start,
- {row_start_A, packed_rows_A, 0, nSize},
+ {row_start_A, packed_rows_A, NDim * group, nSize},
ldc_,
ldc_);
} else if (cpuinfo_has_x86_avx2()) {
outputProcess_.template f<inst_set_t::avx2>(
matC_,
C_buffer_row_start,
- {row_start_A, packed_rows_A, 0, nSize},
+ {row_start_A, packed_rows_A, NDim * group, nSize},
ldc_,
ldc_);
} else {
@@ -183,20 +186,28 @@ void ExecuteKernel<
}
if (C_buffer_start == C_tile_) {
+ // When C_tile_ scratchpad was used to avoid accessing memory past
+ // C_buffer_ .
if (cpuinfo_has_x86_avx512f()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
matC_,
C_tile_,
- {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ {row_start_A,
+ packed_rows_A,
+ NDim * group + jb * nbSize_,
+ packedB_.lastBcol()},
ldc_,
leadingDim);
} else if (cpuinfo_has_x86_avx2()) {
outputProcess_.template f<inst_set_t::avx2>(
matC_,
C_tile_,
- {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ {row_start_A,
+ packed_rows_A,
+ NDim * group + jb * nbSize_,
+ packedB_.lastBcol()},
ldc_,
leadingDim);
} else {
@@ -207,11 +218,11 @@ void ExecuteKernel<
} // output processing
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
- t_end = std::chrono::high_resolution_clock::now();
- dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
- .count();
- postprocessing_time += (dt);
- t_start = std::chrono::high_resolution_clock::now();
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ postprocessing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
#endif
} // for each j block