#ifndef THC_GENERAL_INC #define THC_GENERAL_INC #include "THGeneral.h" #include "THAllocator.h" #undef log1p #include "cuda.h" #include "cuda_runtime.h" #include "cublas_v2.h" #cmakedefine USE_MAGMA #ifdef __cplusplus # define THC_EXTERNC extern "C" #else # define THC_EXTERNC extern #endif #ifdef _WIN32 # ifdef THC_EXPORTS # define THC_API THC_EXTERNC __declspec(dllexport) # else # define THC_API THC_EXTERNC __declspec(dllimport) # endif #else # define THC_API THC_EXTERNC #endif #ifndef THAssert #define THAssert(exp) \ do { \ if (!(exp)) { \ _THError(__FILE__, __LINE__, "assert(%s) failed", #exp); \ } \ } while(0) #endif struct THCRNGState; /* Random number generator state. */ typedef struct _THCDeviceAllocator { cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t); cudaError_t (*free)(void*, void*); cudaError_t (*shutdown)(void*); void* state; } THCDeviceAllocator; /* Global state to be held in the cutorch table. */ typedef struct THCState THCState; THC_API THCState* THCState_alloc(); THC_API void THCState_free(THCState* state); THC_API void THCudaInit(THCState* state); THC_API void THCudaShutdown(THCState* state); THC_API void THCudaEnablePeerToPeerAccess(THCState* state); /* If device `dev` can access allocations on device `devToAccess`, this will return */ /* 1; otherwise, 0. */ THC_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess); /* Enables or disables allowed p2p access using cutorch copy. If we are */ /* attempting to enable access, throws an error if CUDA cannot enable p2p */ /* access. */ THC_API void THCState_setPeerToPeerAccess(THCState* state, int dev, int devToAccess, int enable); /* By default, direct in-kernel access to memory on remote GPUs is disabled. When set, this allows direct in-kernel access to remote GPUs where GPU/GPU p2p access is enabled and allowed. */ THC_API int THCState_getKernelPeerToPeerAccessEnabled(THCState* state); THC_API void THCState_setKernelPeerToPeerAccessEnabled(THCState* state, int val); THC_API struct cudaDeviceProp* THCState_getCurrentDeviceProperties(THCState* state); THC_API struct THCRNGState* THCState_getRngState(THCState* state); THC_API THAllocator* THCState_getCudaHostAllocator(THCState* state); THC_API THCDeviceAllocator* THCState_getDeviceAllocator(THCState* state); THC_API void THCMagma_init(THCState *state); /* State manipulators and accessors */ THC_API int THCState_getNumDevices(THCState* state); THC_API void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking); THC_API int THCState_getNumStreams(THCState* state); THC_API cudaStream_t THCState_getDeviceStream(THCState *state, int device, int stream); THC_API cudaStream_t THCState_getCurrentStream(THCState *state); THC_API int THCState_getCurrentStreamIndex(THCState *state); THC_API void THCState_setCurrentStreamIndex(THCState *state, int stream); THC_API void THCState_reserveBlasHandles(THCState* state, int numHandles); THC_API int THCState_getNumBlasHandles(THCState* state); THC_API cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle); THC_API cublasHandle_t THCState_getCurrentBlasHandle(THCState *state); THC_API int THCState_getCurrentBlasHandleIndex(THCState *state); THC_API void THCState_setCurrentBlasHandleIndex(THCState *state, int handle); /* For the current device and stream, returns the allocated scratch space */ THC_API void* THCState_getCurrentDeviceScratchSpace(THCState* state); THC_API void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream); THC_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state); THC_API size_t THCState_getDeviceScratchSpaceSize(THCState* state, int device); #define THCudaCheck(err) __THCudaCheck(err, __FILE__, __LINE__) #define THCublasCheck(err) __THCublasCheck(err, __FILE__, __LINE__) THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line); THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line); THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size); THC_API cudaError_t THCudaFree(THCState *state, void *ptr); THC_API void THCSetGCHandler(THCState *state, void (*torchGCHandlerFunction)(void *data), void *data ); THC_API void THCHeapUpdate(THCState *state, long size); #endif