diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-17 20:18:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-17 20:18:41 +0300 |
commit | 1e5a315d03c91286d859512574d3b0b25e12d512 (patch) | |
tree | 90e1d80b774a62de10267954ad3940741f8a7b00 | |
parent | 4f7843e8be8de37d0474e9d4a529261b147e8a8e (diff) | |
parent | 983ba05eccfde1fe8662e9a242a9a3202e7c8c31 (diff) |
Merge pull request #800 from colesbury/openmp
Expose OpenMP num threads through TH lib
-rw-r--r-- | lib/TH/THGeneral.c | 29 | ||||
-rw-r--r-- | lib/TH/THGeneral.h.in | 3 | ||||
-rw-r--r-- | utils.c | 21 |
3 files changed, 35 insertions, 18 deletions
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c index 399403b..486a667 100644 --- a/lib/TH/THGeneral.c +++ b/lib/TH/THGeneral.c @@ -1,6 +1,10 @@ #include "THGeneral.h" #include "THAtomic.h" +#ifdef _OPENMP +#include <omp.h> +#endif + #ifndef TH_HAVE_THREAD #define __thread #elif _MSC_VER @@ -314,3 +318,28 @@ double THLog1p(const double x) return log1p(x); #endif } + +void THSetNumThreads(int num_threads) +{ +#ifdef _OPENMP + omp_set_num_threads(num_threads); +#endif +} + +int THGetNumThreads() +{ +#ifdef _OPENMP + return omp_get_max_threads(); +#else + return 1; +#endif +} + +int THGetNumCores() +{ +#ifdef _OPENMP + return omp_get_num_procs(); +#else + return 1; +#endif +} diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index ff41159..02c3832 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -64,6 +64,9 @@ TH_API void THFree(void *ptr); TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data ); // this hook should only be called by custom allocator functions TH_API void THHeapUpdate(ptrdiff_t size); +TH_API void THSetNumThreads(int num_threads); +TH_API int THGetNumThreads(); +TH_API int THGetNumCores(); #define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__) @@ -7,10 +7,6 @@ # include <sys/time.h> #endif -#ifdef _OPENMP -#include <omp.h> -#endif - THLongStorage* torch_checklongargs(lua_State *L, int index) { THLongStorage *storage; @@ -171,30 +167,19 @@ const char* torch_getdefaulttensortype(lua_State *L) static int torch_getnumthreads(lua_State *L) { -#ifdef _OPENMP - lua_pushinteger(L, omp_get_max_threads()); -#else - lua_pushinteger(L, 1); -#endif + lua_pushinteger(L, THGetNumThreads()); return 1; } static int torch_setnumthreads(lua_State *L) { -#ifdef _OPENMP - int nth = luaL_checkint(L,1); - omp_set_num_threads(nth); -#endif + THSetNumThreads(luaL_checkint(L, 1)); return 0; } static int torch_getnumcores(lua_State *L) { -#ifdef _OPENMP - lua_pushinteger(L, omp_get_num_procs()); -#else - lua_pushinteger(L, 1); -#endif + lua_pushinteger(L, THGetNumCores()); return 1; } |