diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2018-01-10 01:28:19 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2018-01-10 01:28:19 +0300 |
commit | f9275f43ef66a462c67b4fc7efa08201eb61490e (patch) | |
tree | 006cde6b1e3388ea69055f62742175144fee1f77 | |
parent | bbb396c3036d3a22b769f8e670562b9a3db68cf1 (diff) |
move .h -> .cu
-rw-r--r-- | contrib/other-builds/amunmt/.project | 5 | ||||
-rw-r--r-- | src/amun/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/amun/gpu/mblas/handles.cu | 46 | ||||
-rw-r--r-- | src/amun/gpu/mblas/handles.h | 36 |
4 files changed, 59 insertions, 31 deletions
diff --git a/contrib/other-builds/amunmt/.project b/contrib/other-builds/amunmt/.project index c76302c0..8c68f5d2 100644 --- a/contrib/other-builds/amunmt/.project +++ b/contrib/other-builds/amunmt/.project @@ -1561,6 +1561,11 @@ <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.h</locationURI> </link> <link> + <name>src/amun/gpu/mblas/handles.cu</name> + <type>1</type> + <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/handles.cu</locationURI> + </link> + <link> <name>src/amun/gpu/mblas/handles.h</name> <type>1</type> <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/handles.h</locationURI> diff --git a/src/amun/CMakeLists.txt b/src/amun/CMakeLists.txt index 10b8b904..6a3e5d9d 100644 --- a/src/amun/CMakeLists.txt +++ b/src/amun/CMakeLists.txt @@ -92,6 +92,7 @@ cuda_add_executable( gpu/dl4mt/encoder.cu gpu/dl4mt/gru.cu gpu/dl4mt/model.cu + gpu/mblas/handles.cu gpu/mblas/matrix.cu gpu/mblas/matrix_functions.cu gpu/mblas/nth_element.cu @@ -118,6 +119,7 @@ cuda_add_library(python SHARED gpu/decoder/encoder_decoder_state.cu gpu/decoder/enc_out_buffer.cu gpu/decoder/enc_out_gpu.cu + gpu/mblas/handles.cu gpu/mblas/matrix.cu gpu/mblas/matrix_functions.cu gpu/mblas/nth_element.cu @@ -152,6 +154,7 @@ cuda_add_library(mosesplugin STATIC gpu/decoder/encoder_decoder_state.cu gpu/decoder/enc_out_buffer.cu gpu/decoder/enc_out_gpu.cu + gpu/mblas/handles.cu gpu/mblas/matrix.cu gpu/mblas/matrix_functions.cu gpu/mblas/nth_element.cu diff --git a/src/amun/gpu/mblas/handles.cu b/src/amun/gpu/mblas/handles.cu new file mode 100644 index 00000000..1d9b4ff2 --- /dev/null +++ b/src/amun/gpu/mblas/handles.cu @@ -0,0 +1,46 @@ +#include "handles.h" +#include "gpu/types-gpu.h" + +namespace amunmt { +namespace GPU { +namespace mblas { + +CudaStreamHandler::CudaStreamHandler() +{ + HANDLE_ERROR( cudaStreamCreate(&stream_)); + // cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking); +} + +CudaStreamHandler::~CudaStreamHandler() { + HANDLE_ERROR(cudaStreamDestroy(stream_)); +} + +/////////////////////////////////////////////////////////////// +cublasHandle_t &CublasHandler::GetHandle() { + return instance_.handle_; +} + +CublasHandler::CublasHandler() +{ + cublasStatus_t stat; + stat = cublasCreate(&handle_); + if (stat != CUBLAS_STATUS_SUCCESS) { + printf ("cublasCreate initialization failed\n"); + abort(); + } + + stat = cublasSetStream(handle_, CudaStreamHandler::GetStream()); + if (stat != CUBLAS_STATUS_SUCCESS) { + printf ("cublasSetStream initialization failed\n"); + abort(); + } +} + +CublasHandler::~CublasHandler() +{ + cublasDestroy(handle_); +} + +} +} +} diff --git a/src/amun/gpu/mblas/handles.h b/src/amun/gpu/mblas/handles.h index 8380652a..777d0f47 100644 --- a/src/amun/gpu/mblas/handles.h +++ b/src/amun/gpu/mblas/handles.h @@ -18,47 +18,21 @@ protected: static thread_local CudaStreamHandler instance_; cudaStream_t stream_; - CudaStreamHandler() - { - HANDLE_ERROR( cudaStreamCreate(&stream_)); - // cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking); - } - + CudaStreamHandler(); CudaStreamHandler(const CudaStreamHandler&) = delete; - virtual ~CudaStreamHandler() { - HANDLE_ERROR(cudaStreamDestroy(stream_)); - } + virtual ~CudaStreamHandler(); }; class CublasHandler { public: - static cublasHandle_t &GetHandle() { - return instance_.handle_; - } + static cublasHandle_t &GetHandle(); private: - CublasHandler() - { - cublasStatus_t stat; - stat = cublasCreate(&handle_); - if (stat != CUBLAS_STATUS_SUCCESS) { - printf ("cublasCreate initialization failed\n"); - abort(); - } - - stat = cublasSetStream(handle_, CudaStreamHandler::GetStream()); - if (stat != CUBLAS_STATUS_SUCCESS) { - printf ("cublasSetStream initialization failed\n"); - abort(); - } - } - - ~CublasHandler() { - cublasDestroy(handle_); - } + CublasHandler(); + ~CublasHandler(); static thread_local CublasHandler instance_; cublasHandle_t handle_; |