Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-10-17 20:18:41 +0300
committerGitHub <noreply@github.com>2016-10-17 20:18:41 +0300
commit1e5a315d03c91286d859512574d3b0b25e12d512 (patch)
tree90e1d80b774a62de10267954ad3940741f8a7b00
parent4f7843e8be8de37d0474e9d4a529261b147e8a8e (diff)
parent983ba05eccfde1fe8662e9a242a9a3202e7c8c31 (diff)
Merge pull request #800 from colesbury/openmp
Expose OpenMP num threads through TH lib
-rw-r--r--lib/TH/THGeneral.c29
-rw-r--r--lib/TH/THGeneral.h.in3
-rw-r--r--utils.c21
3 files changed, 35 insertions, 18 deletions
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c
index 399403b..486a667 100644
--- a/lib/TH/THGeneral.c
+++ b/lib/TH/THGeneral.c
@@ -1,6 +1,10 @@
#include "THGeneral.h"
#include "THAtomic.h"
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
#ifndef TH_HAVE_THREAD
#define __thread
#elif _MSC_VER
@@ -314,3 +318,28 @@ double THLog1p(const double x)
return log1p(x);
#endif
}
+
+void THSetNumThreads(int num_threads)
+{
+#ifdef _OPENMP
+ omp_set_num_threads(num_threads);
+#endif
+}
+
+int THGetNumThreads()
+{
+#ifdef _OPENMP
+ return omp_get_max_threads();
+#else
+ return 1;
+#endif
+}
+
+int THGetNumCores()
+{
+#ifdef _OPENMP
+ return omp_get_num_procs();
+#else
+ return 1;
+#endif
+}
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index ff41159..02c3832 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -64,6 +64,9 @@ TH_API void THFree(void *ptr);
TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data );
// this hook should only be called by custom allocator functions
TH_API void THHeapUpdate(ptrdiff_t size);
+TH_API void THSetNumThreads(int num_threads);
+TH_API int THGetNumThreads();
+TH_API int THGetNumCores();
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
diff --git a/utils.c b/utils.c
index eb7ff53..894bb6e 100644
--- a/utils.c
+++ b/utils.c
@@ -7,10 +7,6 @@
# include <sys/time.h>
#endif
-#ifdef _OPENMP
-#include <omp.h>
-#endif
-
THLongStorage* torch_checklongargs(lua_State *L, int index)
{
THLongStorage *storage;
@@ -171,30 +167,19 @@ const char* torch_getdefaulttensortype(lua_State *L)
static int torch_getnumthreads(lua_State *L)
{
-#ifdef _OPENMP
- lua_pushinteger(L, omp_get_max_threads());
-#else
- lua_pushinteger(L, 1);
-#endif
+ lua_pushinteger(L, THGetNumThreads());
return 1;
}
static int torch_setnumthreads(lua_State *L)
{
-#ifdef _OPENMP
- int nth = luaL_checkint(L,1);
- omp_set_num_threads(nth);
-#endif
+ THSetNumThreads(luaL_checkint(L, 1));
return 0;
}
static int torch_getnumcores(lua_State *L)
{
-#ifdef _OPENMP
- lua_pushinteger(L, omp_get_num_procs());
-#else
- lua_pushinteger(L, 1);
-#endif
+ lua_pushinteger(L, THGetNumCores());
return 1;
}