diff options
author | Sam Gross <sgross@fb.com> | 2017-06-07 00:36:36 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-07 21:53:14 +0300 |
commit | 036989cc99b48ef4ecbf7604ab46f8461dd264fb (patch) | |
tree | db2545814910861c7a084d45dde3153090ed04fa | |
parent | c2fa913b6a9bf91051f9b1f09393af1c7b13d7b5 (diff) |
Make sure the number of MKL and OpenMP threads match

Otherwise, on many machines, the size of the OpenMP thread pool will
change between MKL and our OpenMP-enabled functions. The constant thread
creation and destruction results in worse performance and leaks memory
on GCC 5.4.
-rw-r--r-- | lib/TH/CMakeLists.txt | 3 |
-rw-r--r-- | lib/TH/THGeneral.c | 15 |
-rw-r--r-- | lib/TH/THGeneral.h.in | 1 |

3 files changed, 19 insertions, 0 deletions
```diff
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt
index bfc17cc..22cdf33 100644
--- a/lib/TH/CMakeLists.txt
+++ b/lib/TH/CMakeLists.txt
@@ -304,6 +304,9 @@ FIND_PACKAGE(BLAS)
 IF(BLAS_FOUND)
   SET(USE_BLAS 1)
   TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES})
+  IF(BLAS_INFO STREQUAL "mkl")
+    ADD_DEFINITIONS(-DTH_BLAS_MKL)
+  ENDIF()
 ENDIF(BLAS_FOUND)
 
 FIND_PACKAGE(LAPACK)
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c
index bb9bfc3..d44e762 100644
--- a/lib/TH/THGeneral.c
+++ b/lib/TH/THGeneral.c
@@ -343,3 +343,18 @@ int THGetNumCores(void)
   return 1;
 #endif
 }
+
+#ifdef TH_BLAS_MKL
+extern int mkl_get_max_threads(void);
+#endif
+
+TH_API void THInferNumThreads(void)
+{
+#if defined(_OPENMP) && defined(TH_BLAS_MKL)
+  // If we are using MKL an OpenMP make sure the number of threads match.
+  // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
+  // size of the OpenMP thread pool, resulting in worse performance (and memory
+  // leaks in GCC 5.4)
+  omp_set_num_threads(mkl_get_max_threads());
+#endif
+}
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index de11f1b..0621c7a 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -60,6 +60,7 @@ TH_API void THHeapUpdate(ptrdiff_t size);
 TH_API void THSetNumThreads(int num_threads);
 TH_API int THGetNumThreads(void);
 TH_API int THGetNumCores(void);
+TH_API void THInferNumThreads(void);
 
 #define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
 
```