blob: baa966396b70c30a2663f5b9045fc30bfc21dc6a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
#include "handles.h"
#include "gpu/types-gpu.h"
namespace amunmt {
namespace GPU {
namespace mblas {
// Creates the CUDA stream held in stream_ by calling cudaStreamCreate. The
// returned status is handed to HANDLE_ERROR (a macro defined outside this
// file) for checking.
CudaStreamHandler::CudaStreamHandler()
{
HANDLE_ERROR( cudaStreamCreate(&stream_));
// Disabled alternative, kept for reference: create the stream with the
// cudaStreamNonBlocking flag. NOTE(review): it is written against
// stream_.get(), which does not match the &stream_ usage above, so it would
// need adapting before being re-enabled.
// cudaStreamCreateWithFlags(stream_.get(), cudaStreamNonBlocking);
}
// Destroys the stream held in stream_ with cudaStreamDestroy and passes the
// returned status to HANDLE_ERROR.
// NOTE(review): what HANDLE_ERROR does on failure is not visible here; if it
// throws or aborts, a failing destroy would do so from inside a destructor --
// confirm this is acceptable during teardown.
CudaStreamHandler::~CudaStreamHandler()
{
HANDLE_ERROR(cudaStreamDestroy(stream_));
}
/////////////////////////////////////////////////////////////////////////////////////////
// Creates the cuBLAS handle stored in handle_ and attaches it to the stream
// returned by CudaStreamHandler::GetStream().
//
// Both steps are treated as fatal on failure: a diagnostic (including the
// numeric cublasStatus_t value) is written and the process is aborted.
CublasHandler::CublasHandler()
{
  const cublasStatus_t createStatus = cublasCreate(&handle_);
  if (createStatus != CUBLAS_STATUS_SUCCESS) {
    // Report on stderr, not stdout: abort() is not required to flush buffered
    // streams, so a message left in stdout's buffer could be lost.
    fprintf(stderr, "cublasCreate initialization failed (status %d)\n",
            static_cast<int>(createStatus));
    abort();
  }

#if CUDA_VERSION >= 9000
  /* Currently disabled: switch the handle to tensor-op math mode.
  const cublasStatus_t mathStatus = cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH);
  if (mathStatus != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "cublasSetMathMode failed\n");
    abort();
  }
  */
#endif

  // Route work issued through this handle onto the project's CUDA stream.
  const cublasStatus_t streamStatus =
      cublasSetStream(handle_, CudaStreamHandler::GetStream());
  if (streamStatus != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "cublasSetStream initialization failed (status %d)\n",
            static_cast<int>(streamStatus));
    abort();
  }
}
// Releases the cuBLAS handle held in handle_. The status returned by
// cublasDestroy is discarded; nothing is reported if the call fails.
CublasHandler::~CublasHandler()
{
  static_cast<void>(cublasDestroy(handle_));
}
} // namespace
}
}
|