Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianyu Huang <jianyuhuang@fb.com>2018-11-20 10:31:33 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-20 10:33:34 +0300
commit12c8362ef738879c971ea4d411f419279107057e (patch)
treea6f2415d66e166d0f26b827a045e074a327d52a2 /include/fbgemm/Fbgemm.h
parent282afa288cf45f75cfc38010d07ce77c081729c9 (diff)
Optimize parallelization performance (#15)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/15 Better load balance the workload among different threads. Reviewed By: jspark1105 Differential Revision: D13108873 fbshipit-source-id: ae75971b5ff2cc7cf19907eb95cf2df071f7bbe3
Diffstat (limited to 'include/fbgemm/Fbgemm.h')
-rw-r--r--include/fbgemm/Fbgemm.h57
1 files changed, 57 insertions, 0 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 9cb87ea..d234545 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -1024,4 +1024,61 @@ static void* fbgemmAlignedAlloc(size_t __align, size_t __size) {
*/
bool fbgemmSupportedCPU();
+/*
+ * @brief Partition the workload between 0 and m into num_threads segments. Each
+ * thread gets a multiple of mr, except that the last one might receive the
+ * fringe case. Return the start and end index of each thread.
+ * Example: mr = 8
+ * m mRegBlocks mRegRemainder num_thread_left _right th0 th1 th2
+ * 120 15 0 0 3 40 40 40
+ * 123 15 3 0 3 40 40 43
+ * 133 16 5 1 2 48 40 45
+ * 140 17 4 2 1 48 48 44
+ * 146 18 2 0 3 48 48 50
+ * 144 18 0 0 3 48 48 48
+ *
+ * ToDo: Make this routine more general: partition the workload between any
+ * intervals. We can then reuse this routine for the nested parallel workload
+ * distribution.
+ */
+static void fbgemmGetRange(
+ int num_threads,
+ int thread_id,
+ int m,
+ int mr,
+ int& start,
+ int& end) {
+ int mRegBlocks = m / mr;
+ int mRegRemainder = m % mr;
+
+ int m_blk_per_thread = mRegBlocks / num_threads;
+
+ int num_thread_left = mRegBlocks % num_threads;
+ // int num_thread_right = num_threads - num_thread_left;
+
+ int m_blk_left, m_blk_right;
+ if (num_thread_left == 0) {
+ m_blk_left = m_blk_per_thread;
+ m_blk_right = m_blk_per_thread;
+ } else {
+ m_blk_left = m_blk_per_thread + 1;
+ m_blk_right = m_blk_per_thread;
+ }
+
+ int size_left = m_blk_left * mr;
+ int size_right = m_blk_right * mr;
+
+ if (thread_id < num_thread_left) {
+ start = 0 + thread_id * size_left;
+ end = 0 + (thread_id + 1) * size_left;
+ } else { // thread_id >= num_thread_left
+ start = num_thread_left * size_left +
+ (thread_id - num_thread_left) * size_right;
+ end = num_thread_left * size_left +
+ (thread_id - num_thread_left + 1) * size_right;
+ if (thread_id == num_threads - 1)
+ end += mRegRemainder;
+ }
+}
+
} // namespace fbgemm