diff options
Diffstat (limited to 'extern/ceres/internal/ceres/context_impl.cc')
-rw-r--r-- | extern/ceres/internal/ceres/context_impl.cc | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/extern/ceres/internal/ceres/context_impl.cc b/extern/ceres/internal/ceres/context_impl.cc index 20fe5cbab2a..a4b3c842da1 100644 --- a/extern/ceres/internal/ceres/context_impl.cc +++ b/extern/ceres/internal/ceres/context_impl.cc @@ -30,9 +30,75 @@ #include "ceres/context_impl.h" +#include <string> + +#include "ceres/internal/config.h" + +#ifndef CERES_NO_CUDA +#include "cublas_v2.h" +#include "cuda_runtime.h" +#include "cusolverDn.h" +#endif // CERES_NO_CUDA + namespace ceres { namespace internal { +ContextImpl::ContextImpl() = default; + +#ifndef CERES_NO_CUDA +bool ContextImpl::InitCUDA(std::string* message) { + if (cuda_initialized_) { + return true; + } + if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { + *message = "cuBLAS::cublasCreate failed."; + cublas_handle_ = nullptr; + return false; + } + if (cusolverDnCreate(&cusolver_handle_) != CUSOLVER_STATUS_SUCCESS) { + *message = "cuSolverDN::cusolverDnCreate failed."; + cusolver_handle_ = nullptr; + cublasDestroy(cublas_handle_); + cublas_handle_ = nullptr; + return false; + } + if (cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking) != + cudaSuccess) { + *message = "CUDA::cudaStreamCreateWithFlags failed."; + cusolverDnDestroy(cusolver_handle_); + cublasDestroy(cublas_handle_); + cusolver_handle_ = nullptr; + cublas_handle_ = nullptr; + stream_ = nullptr; + return false; + } + if (cusolverDnSetStream(cusolver_handle_, stream_) != + CUSOLVER_STATUS_SUCCESS || + cublasSetStream(cublas_handle_, stream_) != CUBLAS_STATUS_SUCCESS) { + *message = + "cuSolverDN::cusolverDnSetStream or cuBLAS::cublasSetStream failed."; + cusolverDnDestroy(cusolver_handle_); + cublasDestroy(cublas_handle_); + cudaStreamDestroy(stream_); + cusolver_handle_ = nullptr; + cublas_handle_ = nullptr; + stream_ = nullptr; + return false; + } + cuda_initialized_ = true; + return true; +} +#endif // CERES_NO_CUDA + +ContextImpl::~ContextImpl() { +#ifndef CERES_NO_CUDA + if (cuda_initialized_) { + cusolverDnDestroy(cusolver_handle_); + cublasDestroy(cublas_handle_); + cudaStreamDestroy(stream_); + } +#endif // CERES_NO_CUDA +} void ContextImpl::EnsureMinimumThreads(int num_threads) { #ifdef CERES_USE_CXX_THREADS thread_pool.Resize(num_threads); |