Welcome to mirror list, hosted at ThFree Co, Russian Federation.

git.blender.org/blender.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'extern/ceres/internal/ceres/context_impl.cc')
-rw-r--r--extern/ceres/internal/ceres/context_impl.cc66
1 files changed, 66 insertions, 0 deletions
diff --git a/extern/ceres/internal/ceres/context_impl.cc b/extern/ceres/internal/ceres/context_impl.cc
index 20fe5cbab2a..a4b3c842da1 100644
--- a/extern/ceres/internal/ceres/context_impl.cc
+++ b/extern/ceres/internal/ceres/context_impl.cc
@@ -30,9 +30,75 @@
#include "ceres/context_impl.h"
+#include <string>
+
+#include "ceres/internal/config.h"
+
+#ifndef CERES_NO_CUDA
+#include "cublas_v2.h"
+#include "cuda_runtime.h"
+#include "cusolverDn.h"
+#endif // CERES_NO_CUDA
+
namespace ceres {
namespace internal {
+ContextImpl::ContextImpl() = default;
+
+#ifndef CERES_NO_CUDA
+bool ContextImpl::InitCUDA(std::string* message) {
+ if (cuda_initialized_) {
+ return true;
+ }
+ if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
+ *message = "cuBLAS::cublasCreate failed.";
+ cublas_handle_ = nullptr;
+ return false;
+ }
+ if (cusolverDnCreate(&cusolver_handle_) != CUSOLVER_STATUS_SUCCESS) {
+ *message = "cuSolverDN::cusolverDnCreate failed.";
+ cusolver_handle_ = nullptr;
+ cublasDestroy(cublas_handle_);
+ cublas_handle_ = nullptr;
+ return false;
+ }
+ if (cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking) !=
+ cudaSuccess) {
+ *message = "CUDA::cudaStreamCreateWithFlags failed.";
+ cusolverDnDestroy(cusolver_handle_);
+ cublasDestroy(cublas_handle_);
+ cusolver_handle_ = nullptr;
+ cublas_handle_ = nullptr;
+ stream_ = nullptr;
+ return false;
+ }
+ if (cusolverDnSetStream(cusolver_handle_, stream_) !=
+ CUSOLVER_STATUS_SUCCESS ||
+ cublasSetStream(cublas_handle_, stream_) != CUBLAS_STATUS_SUCCESS) {
+ *message =
+ "cuSolverDN::cusolverDnSetStream or cuBLAS::cublasSetStream failed.";
+ cusolverDnDestroy(cusolver_handle_);
+ cublasDestroy(cublas_handle_);
+ cudaStreamDestroy(stream_);
+ cusolver_handle_ = nullptr;
+ cublas_handle_ = nullptr;
+ stream_ = nullptr;
+ return false;
+ }
+ cuda_initialized_ = true;
+ return true;
+}
+#endif // CERES_NO_CUDA
+
+ContextImpl::~ContextImpl() {
+#ifndef CERES_NO_CUDA
+ if (cuda_initialized_) {
+ cusolverDnDestroy(cusolver_handle_);
+ cublasDestroy(cublas_handle_);
+ cudaStreamDestroy(stream_);
+ }
+#endif // CERES_NO_CUDA
+}
void ContextImpl::EnsureMinimumThreads(int num_threads) {
#ifdef CERES_USE_CXX_THREADS
thread_pool.Resize(num_threads);