Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2018-01-16 17:12:24 +0300
committerHieu Hoang <hieuhoang@gmail.com>2018-01-16 17:12:24 +0300
commit7b4ef6322ef925737d4eb7e257ec93ac77eaa518 (patch)
tree55aafd8d7d06cac19b70e6d839ed6c0055f50af5
parentebc784360c963e55f5dcc122eda88bdffd686115 (diff)
add tensorcore support
-rw-r--r--src/amun/gpu/mblas/handles.cu18
1 files changed, 14 insertions, 4 deletions
diff --git a/src/amun/gpu/mblas/handles.cu b/src/amun/gpu/mblas/handles.cu
index 8f17e232..13c99f8c 100644
--- a/src/amun/gpu/mblas/handles.cu
+++ b/src/amun/gpu/mblas/handles.cu
@@ -23,14 +23,24 @@ CublasHandler::CublasHandler()
cublasStatus_t stat;
stat = cublasCreate(&handle_);
if (stat != CUBLAS_STATUS_SUCCESS) {
- printf ("cublasCreate initialization failed\n");
- abort();
+ printf ("cublasCreate initialization failed\n");
+ abort();
}
+#if CUDA_VERSION >= 9000
+ ///*
+ stat = cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH);
+ if (stat != CUBLAS_STATUS_SUCCESS) {
+ printf ("cublasSetMathMode failed\n");
+ abort();
+ }
+ //*/
+#endif
+
stat = cublasSetStream(handle_, CudaStreamHandler::GetStream());
if (stat != CUBLAS_STATUS_SUCCESS) {
- printf ("cublasSetStream initialization failed\n");
- abort();
+ printf ("cublasSetStream initialization failed\n");
+ abort();
}
}