Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2017-06-07 00:36:36 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-07 21:53:14 +0300
commit036989cc99b48ef4ecbf7604ab46f8461dd264fb (patch)
treedb2545814910861c7a084d45dde3153090ed04fa
parentc2fa913b6a9bf91051f9b1f09393af1c7b13d7b5 (diff)
Make sure the number of MKL and OpenMP threads match
Otherwise, on many machines, the size of the OpenMP thread pool will change between MKL and our OpenMP enabled functions. The constant thread creation and destruction results in worse performance and leaks memory on GCC 5.4
-rw-r--r--lib/TH/CMakeLists.txt3
-rw-r--r--lib/TH/THGeneral.c15
-rw-r--r--lib/TH/THGeneral.h.in1
3 files changed, 19 insertions, 0 deletions
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt
index bfc17cc..22cdf33 100644
--- a/lib/TH/CMakeLists.txt
+++ b/lib/TH/CMakeLists.txt
@@ -304,6 +304,9 @@ FIND_PACKAGE(BLAS)
IF(BLAS_FOUND)
SET(USE_BLAS 1)
TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES})
+ IF(BLAS_INFO STREQUAL "mkl")
+ ADD_DEFINITIONS(-DTH_BLAS_MKL)
+ ENDIF()
ENDIF(BLAS_FOUND)
FIND_PACKAGE(LAPACK)
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c
index bb9bfc3..d44e762 100644
--- a/lib/TH/THGeneral.c
+++ b/lib/TH/THGeneral.c
@@ -343,3 +343,18 @@ int THGetNumCores(void)
return 1;
#endif
}
+
+#ifdef TH_BLAS_MKL
+extern int mkl_get_max_threads(void);
+#endif
+
+TH_API void THInferNumThreads(void)
+{
+#if defined(_OPENMP) && defined(TH_BLAS_MKL)
+ // If we are using MKL an OpenMP make sure the number of threads match.
+ // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
+ // size of the OpenMP thread pool, resulting in worse performance (and memory
+ // leaks in GCC 5.4)
+ omp_set_num_threads(mkl_get_max_threads());
+#endif
+}
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index de11f1b..0621c7a 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -60,6 +60,7 @@ TH_API void THHeapUpdate(ptrdiff_t size);
TH_API void THSetNumThreads(int num_threads);
TH_API int THGetNumThreads(void);
TH_API int THGetNumCores(void);
+TH_API void THInferNumThreads(void);
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)