Welcome to mirror list, hosted at ThFree Co, Russian Federation.

handles.cu « mblas « gpu « amun « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: baa966396b70c30a2663f5b9045fc30bfc21dc6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include "handles.h"

#include <cstdio>
#include <cstdlib>

#include "gpu/types-gpu.h"

namespace amunmt {
namespace GPU {
namespace mblas {

// Creates the CUDA stream stored in stream_.
// The cudaError_t returned by cudaStreamCreate is handed to HANDLE_ERROR
// (defined outside this file; presumably it reports and stops on failure --
// confirm against its definition).
CudaStreamHandler::CudaStreamHandler() {
  HANDLE_ERROR(cudaStreamCreate(&stream_));

  // Disabled alternative, kept for reference:
  // cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking);
}

// Destroys the CUDA stream stored in stream_.
// The cudaError_t returned by cudaStreamDestroy is handed to HANDLE_ERROR
// (defined outside this file).
CudaStreamHandler::~CudaStreamHandler() {
  HANDLE_ERROR(cudaStreamDestroy(stream_));
}

/////////////////////////////////////////////////////////////////////////////////////////

CublasHandler::CublasHandler()
{
  cublasStatus_t stat;
  stat = cublasCreate(&handle_);
  if (stat != CUBLAS_STATUS_SUCCESS) {
    printf ("cublasCreate initialization failed\n");
    abort();
  }

#if CUDA_VERSION >= 9000
  /*
    stat = cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH);
    if (stat != CUBLAS_STATUS_SUCCESS) {
      printf ("cublasSetMathMode failed\n");
      abort();
    }
  */
#endif
		  
  stat = cublasSetStream(handle_, CudaStreamHandler::GetStream());
  if (stat != CUBLAS_STATUS_SUCCESS) {
    printf ("cublasSetStream initialization failed\n");
    abort();
  }
}

// Releases the cuBLAS handle stored in handle_.
// The cublasStatus_t returned by cublasDestroy is discarded; no error is
// reported from the destructor.
CublasHandler::~CublasHandler()
{
  (void)cublasDestroy(handle_);
}


} // namespace mblas
} // namespace GPU
} // namespace amunmt