Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hieuhoang@gmail.com>, 2018-01-10 01:28:19 +0300
committer: Hieu Hoang <hieuhoang@gmail.com>, 2018-01-10 01:28:19 +0300
commit: f9275f43ef66a462c67b4fc7efa08201eb61490e (patch)
tree: 006cde6b1e3388ea69055f62742175144fee1f77
parent: bbb396c3036d3a22b769f8e670562b9a3db68cf1 (diff)
move .h -> .cu
-rw-r--r--  contrib/other-builds/amunmt/.project  5
-rw-r--r--  src/amun/CMakeLists.txt  3
-rw-r--r--  src/amun/gpu/mblas/handles.cu  46
-rw-r--r--  src/amun/gpu/mblas/handles.h  36
4 files changed, 59 insertions, 31 deletions
diff --git a/contrib/other-builds/amunmt/.project b/contrib/other-builds/amunmt/.project
index c76302c0..8c68f5d2 100644
--- a/contrib/other-builds/amunmt/.project
+++ b/contrib/other-builds/amunmt/.project
@@ -1561,6 +1561,11 @@
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.h</locationURI>
</link>
<link>
+ <name>src/amun/gpu/mblas/handles.cu</name>
+ <type>1</type>
+ <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/handles.cu</locationURI>
+ </link>
+ <link>
<name>src/amun/gpu/mblas/handles.h</name>
<type>1</type>
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/handles.h</locationURI>
diff --git a/src/amun/CMakeLists.txt b/src/amun/CMakeLists.txt
index 10b8b904..6a3e5d9d 100644
--- a/src/amun/CMakeLists.txt
+++ b/src/amun/CMakeLists.txt
@@ -92,6 +92,7 @@ cuda_add_executable(
gpu/dl4mt/encoder.cu
gpu/dl4mt/gru.cu
gpu/dl4mt/model.cu
+ gpu/mblas/handles.cu
gpu/mblas/matrix.cu
gpu/mblas/matrix_functions.cu
gpu/mblas/nth_element.cu
@@ -118,6 +119,7 @@ cuda_add_library(python SHARED
gpu/decoder/encoder_decoder_state.cu
gpu/decoder/enc_out_buffer.cu
gpu/decoder/enc_out_gpu.cu
+ gpu/mblas/handles.cu
gpu/mblas/matrix.cu
gpu/mblas/matrix_functions.cu
gpu/mblas/nth_element.cu
@@ -152,6 +154,7 @@ cuda_add_library(mosesplugin STATIC
gpu/decoder/encoder_decoder_state.cu
gpu/decoder/enc_out_buffer.cu
gpu/decoder/enc_out_gpu.cu
+ gpu/mblas/handles.cu
gpu/mblas/matrix.cu
gpu/mblas/matrix_functions.cu
gpu/mblas/nth_element.cu
diff --git a/src/amun/gpu/mblas/handles.cu b/src/amun/gpu/mblas/handles.cu
new file mode 100644
index 00000000..1d9b4ff2
--- /dev/null
+++ b/src/amun/gpu/mblas/handles.cu
@@ -0,0 +1,46 @@
+#include "handles.h"
+#include "gpu/types-gpu.h"
+
+namespace amunmt {
+namespace GPU {
+namespace mblas {
+
+CudaStreamHandler::CudaStreamHandler()
+{
+ HANDLE_ERROR( cudaStreamCreate(&stream_));
+ // cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking);
+}
+
+CudaStreamHandler::~CudaStreamHandler() {
+ HANDLE_ERROR(cudaStreamDestroy(stream_));
+}
+
+///////////////////////////////////////////////////////////////
+cublasHandle_t &CublasHandler::GetHandle() {
+ return instance_.handle_;
+}
+
+CublasHandler::CublasHandler()
+{
+ cublasStatus_t stat;
+ stat = cublasCreate(&handle_);
+ if (stat != CUBLAS_STATUS_SUCCESS) {
+ printf ("cublasCreate initialization failed\n");
+ abort();
+ }
+
+ stat = cublasSetStream(handle_, CudaStreamHandler::GetStream());
+ if (stat != CUBLAS_STATUS_SUCCESS) {
+ printf ("cublasSetStream initialization failed\n");
+ abort();
+ }
+}
+
+CublasHandler::~CublasHandler()
+{
+ cublasDestroy(handle_);
+}
+
+}
+}
+}
diff --git a/src/amun/gpu/mblas/handles.h b/src/amun/gpu/mblas/handles.h
index 8380652a..777d0f47 100644
--- a/src/amun/gpu/mblas/handles.h
+++ b/src/amun/gpu/mblas/handles.h
@@ -18,47 +18,21 @@ protected:
static thread_local CudaStreamHandler instance_;
cudaStream_t stream_;
- CudaStreamHandler()
- {
- HANDLE_ERROR( cudaStreamCreate(&stream_));
- // cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking);
- }
-
+ CudaStreamHandler();
CudaStreamHandler(const CudaStreamHandler&) = delete;
- virtual ~CudaStreamHandler() {
- HANDLE_ERROR(cudaStreamDestroy(stream_));
- }
+ virtual ~CudaStreamHandler();
};
class CublasHandler
{
public:
- static cublasHandle_t &GetHandle() {
- return instance_.handle_;
- }
+ static cublasHandle_t &GetHandle();
private:
- CublasHandler()
- {
- cublasStatus_t stat;
- stat = cublasCreate(&handle_);
- if (stat != CUBLAS_STATUS_SUCCESS) {
- printf ("cublasCreate initialization failed\n");
- abort();
- }
-
- stat = cublasSetStream(handle_, CudaStreamHandler::GetStream());
- if (stat != CUBLAS_STATUS_SUCCESS) {
- printf ("cublasSetStream initialization failed\n");
- abort();
- }
- }
-
- ~CublasHandler() {
- cublasDestroy(handle_);
- }
+ CublasHandler();
+ ~CublasHandler();
static thread_local CublasHandler instance_;
cublasHandle_t handle_;