author     Jongsoo Park <jongsoo@fb.com>  2018-11-23 08:46:38 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-11-23 08:48:35 +0300
commit     ee81a57c087a3914aa70ba0d42895f21df2596f2 (patch)
tree       7b494fdebd111ab100c16c65645910d187b42592 /src/Fbgemm.cc
parent     719734d01655b7ec5837adaa2710d4b7d03c0840 (diff)
parallelization over groups (#23)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/23
First parallelize over groups, then parallelize within each group.
Reviewed By: jianyuh
Differential Revision: D13166764
fbshipit-source-id: 58da644ec5fbd5d6e3e87d46790b9199dded6889
Diffstat (limited to 'src/Fbgemm.cc')
-rw-r--r--  src/Fbgemm.cc  36
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 99d0a52..0039daf 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -88,14 +88,15 @@ void fbgemmPacked(
   if (!packB.isPrePacked()) {
     throw std::runtime_error("B matrix must be prepacked");
   }
-  if (packA.numGroups() != packB.numGroups()) {
+  int G = packA.numGroups();
+  if (G != packB.numGroups()) {
     throw std::runtime_error(
-        "A.groups = " + std::to_string(packA.numGroups()) + " and B.groups = " +
+        "A.groups = " + std::to_string(G) + " and B.groups = " +
         std::to_string(packB.numGroups()) + " are not the same");
   }

   int MDim = packA.numRows();
-  int KDimPerGroup = packB.numRows() / packB.numGroups();
+  int KDimPerGroup = packB.numRows() / G;

   int kBlocks = (KDimPerGroup + KCB - 1) / KCB;

@@ -114,10 +115,33 @@ void fbgemmPacked(
   t_very_start = std::chrono::high_resolution_clock::now();
 #endif

-  for (int g = 0; g < packA.numGroups(); ++g) {
-    int i_begin, i_end;
-    fbgemmGetRange(num_threads, thread_id, MDim, MR, i_begin, i_end);
+  int g_begin, g_end, i_begin, i_end;
+  if (G >= num_threads) {
+    // When G >= nthreads, just parallelize over G
+    // TODO: when G == nthreads + 1, we'll have a big load imbalance because
+    // only one thread will get 2 groups.
+    fbgemmGetRange(num_threads, thread_id, G, 1, g_begin, g_end);
+    i_begin = 0;
+    i_end = MDim;
+  } else {
+    // Otherwise, each group is parallelized by multiple threads.
+    // nthreads_per_group is floor(nthreads / G).
+    // If we use ceil, some groups won't be handled by any thread.
+    int nthreads_per_group = num_threads / G;
+    g_begin = std::max(std::min(thread_id / nthreads_per_group, G - 1), 0);
+    g_end = std::min(g_begin + 1, G);
+
+    int tid_of_g_begin = std::min(g_begin * nthreads_per_group, num_threads);
+    int tid_of_g_end = std::min(
+        (g_end == G) ? num_threads : (tid_of_g_begin + nthreads_per_group),
+        num_threads);
+    int nthreads_within_group = tid_of_g_end - tid_of_g_begin;
+    int tid_within_group = thread_id - tid_of_g_begin;
+    fbgemmGetRange(
+        nthreads_within_group, tid_within_group, MDim, MR, i_begin, i_end);
+  }
+  for (int g = g_begin; g < g_end; ++g) {
     ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
         exeKernelObj(
             packA,
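For readers who want to experiment with the scheduling logic in isolation, here is a minimal standalone sketch of the two-level partitioning the diff introduces. It is not the FBGEMM implementation: getRange is a simplified stand-in for fbgemmGetRange (the real helper's blocking policy may differ), and the partition function is a hypothetical extraction of the group/row split above, minus the surrounding kernel machinery.

#include <algorithm>
#include <cstdio>

// Simplified stand-in for FBGEMM's fbgemmGetRange (an assumption, not the
// real implementation): splits n items, in multiples of `block`, across
// nthreads, writing this thread's half-open range into [begin, end).
static void getRange(int nthreads, int tid, int n, int block,
                     int& begin, int& end) {
  int nBlocks = (n + block - 1) / block;
  int blocksPerThread = (nBlocks + nthreads - 1) / nthreads;
  begin = std::min(tid * blocksPerThread * block, n);
  end = std::min((tid + 1) * blocksPerThread * block, n);
}

// Sketch of the two-level split from the diff: groups first, then rows.
static void partition(int num_threads, int thread_id, int G, int MDim, int MR,
                      int& g_begin, int& g_end, int& i_begin, int& i_end) {
  if (G >= num_threads) {
    // Enough groups: give each thread whole groups and the full row range.
    getRange(num_threads, thread_id, G, 1, g_begin, g_end);
    i_begin = 0;
    i_end = MDim;
  } else {
    // Fewer groups than threads: floor(num_threads / G) threads per group;
    // threads sharing a group split its M dimension. Floor (not ceil) is
    // used so every group is handled by at least one thread.
    int nthreads_per_group = num_threads / G;
    g_begin = std::max(std::min(thread_id / nthreads_per_group, G - 1), 0);
    g_end = std::min(g_begin + 1, G);
    int tid_of_g_begin = std::min(g_begin * nthreads_per_group, num_threads);
    int tid_of_g_end = std::min(
        (g_end == G) ? num_threads : (tid_of_g_begin + nthreads_per_group),
        num_threads);
    int nthreads_within_group = tid_of_g_end - tid_of_g_begin;
    int tid_within_group = thread_id - tid_of_g_begin;
    getRange(nthreads_within_group, tid_within_group, MDim, MR,
             i_begin, i_end);
  }
}

int main() {
  // Example: 8 threads, 3 groups, M = 100 rows, row blocking MR = 4.
  for (int tid = 0; tid < 8; ++tid) {
    int gb, ge, ib, ie;
    partition(8, tid, 3, 100, 4, gb, ge, ib, ie);
    std::printf("tid %d -> groups [%d,%d), rows [%d,%d)\n",
                tid, gb, ge, ib, ie);
  }
  return 0;
}

With these numbers, threads 0-1 share group 0, threads 2-3 share group 1, and threads 4-7 split the rows of group 2: because g_begin is clamped to G - 1, the last group absorbs the leftover threads, which is the design choice the diff's comments describe.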