diff options
Diffstat (limited to 'extern/cuew/auto/cuew_gen.py')
-rw-r--r-- | extern/cuew/auto/cuew_gen.py | 80 |
1 files changed, 62 insertions, 18 deletions
diff --git a/extern/cuew/auto/cuew_gen.py b/extern/cuew/auto/cuew_gen.py index 75e5bf876f4..6cc48e4f809 100644 --- a/extern/cuew/auto/cuew_gen.py +++ b/extern/cuew/auto/cuew_gen.py @@ -49,7 +49,7 @@ COPYRIGHT = """/* * See the License for the specific language governing permissions and * limitations under the License */""" -FILES = ["cuda.h", "cudaGL.h"] +FILES = ["cuda.h", "cudaGL.h", 'nvrtc.h'] TYPEDEFS = [] FUNC_TYPEDEFS = [] @@ -89,7 +89,10 @@ class FuncDefVisitor(c_ast.NodeVisitor): self.indent -= 1 return "union {\n" + union + (" " * self.indent) + "}" elif isinstance(node, c_ast.Enum): - return 'enum ' + node.name + if node.name is not None: + return 'enum ' + node.name + else: + return 'enum ' elif isinstance(node, c_ast.TypeDecl): return self._get_ident_type(node.type) else: @@ -268,7 +271,9 @@ def parse_files(): token = line.split() if token[0] not in ("__cuda_cuda_h__", "CUDA_CB", - "CUDAAPI"): + "CUDAAPI", + "CUDAGL_H", + "__NVRTC_H__"): DEFINES.append(token) for line in lines: @@ -403,7 +408,7 @@ def print_dl_wrapper(): typedef HMODULE DynamicLibrary; -# define dynamic_library_open(path) LoadLibrary(path) +# define dynamic_library_open(path) LoadLibraryA(path) # define dynamic_library_close(lib) FreeLibrary(lib) # define dynamic_library_find(lib, symbol) GetProcAddress(lib, symbol) #else @@ -419,23 +424,44 @@ typedef void* DynamicLibrary; def print_dl_helper_macro(): - print("""#define %s_LIBRARY_FIND_CHECKED(name) \\ + print("""#define _LIBRARY_FIND_CHECKED(lib, name) \\ name = (t##name *)dynamic_library_find(lib, #name); \\ assert(name); -#define %s_LIBRARY_FIND(name) \\ +#define _LIBRARY_FIND(lib, name) \\ name = (t##name *)dynamic_library_find(lib, #name); -static DynamicLibrary lib;""" % (REAL_LIB, REAL_LIB)) +#define %s_LIBRARY_FIND_CHECKED(name) \\ + _LIBRARY_FIND_CHECKED(cuda_lib, name) +#define %s_LIBRARY_FIND(name) _LIBRARY_FIND(cuda_lib, name) + +#define NVRTC_LIBRARY_FIND_CHECKED(name) \\ + _LIBRARY_FIND_CHECKED(nvrtc_lib, name) +#define NVRTC_LIBRARY_FIND(name) _LIBRARY_FIND(nvrtc_lib, name) + +static DynamicLibrary cuda_lib; +static DynamicLibrary nvrtc_lib;""" % (REAL_LIB, REAL_LIB)) print("") -def print_dl_close(): - print("""static void %sExit(void) { - if(lib != NULL) { +def print_dl_helpers(): + print("""static DynamicLibrary dynamic_library_open_find(const char **paths) { + int i = 0; + while (paths[i] != NULL) { + DynamicLibrary lib = dynamic_library_open(paths[i]); + if (lib != NULL) { + return lib; + } + ++i; + } + return NULL; +} + +static void %sExit(void) { + if(cuda_lib != NULL) { /* Ignore errors. */ - dynamic_library_close(lib); - lib = NULL; + dynamic_library_close(cuda_lib); + cuda_lib = NULL; } }""" % (LIB.lower())) print("") @@ -445,12 +471,21 @@ def print_lib_path(): # TODO(sergey): get rid of hardcoded libraries. print("""#ifdef _WIN32 /* Expected in c:/windows/system or similar, no path needed. */ - const char *path = "nvcuda.dll"; + const char *cuda_paths[] = {"nvcuda.dll", NULL}; + const char *nvrtc_paths[] = {"nvrtc.dll", NULL}; #elif defined(__APPLE__) /* Default installation path. */ - const char *path = "/usr/local/cuda/lib/libcuda.dylib"; + const char *cuda_paths[] = {"/usr/local/cuda/lib/libcuda.dylib", NULL}; + const char *nvrtc_paths[] = {"/usr/local/cuda/lib/libnvrtc.dylib", NULL}; #else - const char *path = "libcuda.so"; + const char *cuda_paths[] = {"libcuda.so", NULL}; + const char *nvrtc_paths[] = {"libnvrtc.so", +# if defined(__x86_64__) || defined(_M_X64) + "/usr/local/cuda/lib64/libnvrtc.so", +#else + "/usr/local/cuda/lib/libnvrtc.so", +#endif + NULL}; #endif""") @@ -472,9 +507,11 @@ def print_init_guard(): } /* Load library. */ - lib = dynamic_library_open(path); + cuda_lib = dynamic_library_open_find(cuda_paths); + nvrtc_lib = dynamic_library_open_find(nvrtc_paths); - if (lib == NULL) { + /* CUDA library is mandatory to have, while nvrtc might be missing. */ + if (cuda_lib == NULL) { result = CUEW_ERROR_OPEN_FAILED; return result; }""") @@ -509,10 +546,17 @@ def print_dl_init(): print(" /* Fetch all function pointers. */") for symbol in SYMBOLS: if symbol: + if not symbol.startswith('nvrtc'): print(" %s_LIBRARY_FIND(%s);" % (REAL_LIB, symbol)) else: print("") + print(" if (nvrtc_lib != NULL) {") + for symbol in SYMBOLS: + if symbol and symbol.startswith('nvrtc'): + print(" NVRTC_LIBRARY_FIND(%s);" % (symbol)) + print(" }") + print("") print(" result = CUEW_SUCCESS;") print(" return result;") @@ -549,7 +593,7 @@ def print_implementation(): print("") print("") - print_dl_close() + print_dl_helpers() print("/* Implementation function. */") print_dl_init() |