Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2018-11-23 08:46:38 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-23 08:48:35 +0300
commitee81a57c087a3914aa70ba0d42895f21df2596f2 (patch)
tree7b494fdebd111ab100c16c65645910d187b42592 /src/Fbgemm.cc
parent719734d01655b7ec5837adaa2710d4b7d03c0840 (diff)
parallelization over groups (#23)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/23 First parallelize over group and then parallelize within each group. Reviewed By: jianyuh Differential Revision: D13166764 fbshipit-source-id: 58da644ec5fbd5d6e3e87d46790b9199dded6889
Diffstat (limited to 'src/Fbgemm.cc')
-rw-r--r--src/Fbgemm.cc36
1 file changed, 30 insertions, 6 deletions
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 99d0a52..0039daf 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -88,14 +88,15 @@ void fbgemmPacked(
if (!packB.isPrePacked()) {
throw std::runtime_error("B matrix must be prepacked");
}
- if (packA.numGroups() != packB.numGroups()) {
+ int G = packA.numGroups();
+ if (G != packB.numGroups()) {
throw std::runtime_error(
- "A.groups = " + std::to_string(packA.numGroups()) + " and B.groups = " +
+ "A.groups = " + std::to_string(G) + " and B.groups = " +
std::to_string(packB.numGroups()) + " are not the same");
}
int MDim = packA.numRows();
- int KDimPerGroup = packB.numRows() / packB.numGroups();
+ int KDimPerGroup = packB.numRows() / G;
int kBlocks = (KDimPerGroup + KCB - 1) / KCB;
@@ -114,10 +115,33 @@ void fbgemmPacked(
t_very_start = std::chrono::high_resolution_clock::now();
#endif
- for (int g = 0; g < packA.numGroups(); ++g) {
- int i_begin, i_end;
- fbgemmGetRange(num_threads, thread_id, MDim, MR, i_begin, i_end);
+ int g_begin, g_end, i_begin, i_end;
+ if (G >= num_threads) {
+ // When G >= nthreads, just parallelize over G
+ // TODO: when G == nthreads + 1, we'll have a big load imbalance because
+ // only one thread will get 2 groups.
+ fbgemmGetRange(num_threads, thread_id, G, 1, g_begin, g_end);
+ i_begin = 0;
+ i_end = MDim;
+ } else {
+ // Otherwise, each group is parallelized by multiple threads.
+ // nthreads_per_group is floor(nthreads / G).
+ // If we use ceil, some groups won't be handled by any thread.
+ int nthreads_per_group = num_threads / G;
+ g_begin = std::max(std::min(thread_id / nthreads_per_group, G - 1), 0);
+ g_end = std::min(g_begin + 1, G);
+
+ int tid_of_g_begin = std::min(g_begin * nthreads_per_group, num_threads);
+ int tid_of_g_end = std::min(
+ (g_end == G) ? num_threads : (tid_of_g_begin + nthreads_per_group),
+ num_threads);
+ int nthreads_within_group = tid_of_g_end - tid_of_g_begin;
+ int tid_within_group = thread_id - tid_of_g_begin;
+ fbgemmGetRange(
+ nthreads_within_group, tid_within_group, MDim, MR, i_begin, i_end);
+ }
+ for (int g = g_begin; g < g_end; ++g) {
ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
exeKernelObj(
packA,