diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-24 21:44:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-24 21:44:56 +0300 |
commit | d65d29dab8c48a591eda2ad96ad8474eb629c82c (patch) | |
tree | a8ce7f390bd4d5981a95e40dd18b7c8d902d1b52 | |
parent | 488bc783b2b2ad5ea09718429fd5eacf9f80f804 (diff) | |
parent | a758aa66b18b0c3c1a857e561199f326749bc02e (diff) |
Merge pull request #733 from apaszke/btri
Make rinfo_ optional in btrifact
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 43 |
1 file changed, 20 insertions(+), 23 deletions(-)
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index c180445..fb2b6b7 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -654,52 +654,49 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens lda = ra__->stride[2]; } + long num_batches = ra__->size[0]; + + THCudaIntTensor_resize2d(state, rpivots_, num_batches, n); + int *pivots_gpu = THCudaIntTensor_data(state, rpivots_); + + bool free_rinfo_ = !rinfo_; + if (rinfo_ == NULL) rinfo_ = THCudaIntTensor_new(state); + THCudaIntTensor_resize1d(state, rinfo_, num_batches); + int *info_gpu = THCudaIntTensor_data(state, rinfo_); + // Copy pointers to device. real **d_result; - long num_batches = ra__->size[0]; size_t matrices_size = num_batches * sizeof(real*); THCudaCheck(THCudaMalloc(state, (void**)&d_result, matrices_size)); - THCudaIntTensor_resize1d(state, rpivots_, num_batches*n); - - int *pivots_gpu; - THCudaCheck(THCudaMalloc(state, (void**)&pivots_gpu, sizeof(int)*num_batches*n)); - const long block = 512; const long grid = (num_batches + block - 1) / block; createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>( (const real**)d_result, THCTensor_(data)(state, ra__), ra__->stride[0], num_batches); - THCudaIntTensor_resize1d(state, rinfo_, num_batches); - int *info_gpu; - THCudaCheck(THCudaMalloc(state, (void**)&info_gpu, sizeof(int)*num_batches)); - #ifdef THC_REAL_IS_FLOAT THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); #elif defined(THC_REAL_IS_DOUBLE) THCudaBlas_Dgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); #endif - if (!THCudaIntTensor_isContiguous(state, rinfo_)) { - THError("Error: rinfo_ is not contiguous."); - } - THCudaCheck(cudaMemcpy(THCudaIntTensor_data(state, rinfo_), info_gpu, sizeof(int)*num_batches, cudaMemcpyDeviceToHost)); - - if (!THCudaIntTensor_isContiguous(state, rpivots_)) { - THError("Error: rpivots_ is not contiguous."); - } - THCudaCheck(cudaMemcpy(THCudaIntTensor_data(state, rpivots_), pivots_gpu, sizeof(int)*num_batches*n, cudaMemcpyDeviceToHost)); - THCudaIntTensor_resize2d(state, rpivots_, num_batches, n); - THCudaFree(state, d_result); - THCudaFree(state, info_gpu); - THCudaFree(state, pivots_gpu); if (ra__ != ra_) { THCTensor_(freeCopyTo)(state, ra__, ra_); } + if (free_rinfo_) { + real min = THCudaIntTensor_minall(state, rinfo_); + real max = THCudaIntTensor_maxall(state, rinfo_); + THCudaIntTensor_free(state, rinfo_); + if (min != 0 || max != 0) { + THError("failed to factorize some batch elements (min info == %d, max info == %d)", + min, max); + } + } + #else THError("unimplemented data type"); #endif |