Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-03-24 21:44:56 +0300
committerGitHub <noreply@github.com>2017-03-24 21:44:56 +0300
commitd65d29dab8c48a591eda2ad96ad8474eb629c82c (patch)
treea8ce7f390bd4d5981a95e40dd18b7c8d902d1b52
parent488bc783b2b2ad5ea09718429fd5eacf9f80f804 (diff)
parenta758aa66b18b0c3c1a857e561199f326749bc02e (diff)
Merge pull request #733 from apaszke/btri
Make rinfo_ optional in btrifact
-rw-r--r-- lib/THC/generic/THCTensorMathBlas.cu 43
1 file changed, 20 insertions(+), 23 deletions(-)
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index c180445..fb2b6b7 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -654,52 +654,49 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
lda = ra__->stride[2];
}
+ long num_batches = ra__->size[0];
+
+ THCudaIntTensor_resize2d(state, rpivots_, num_batches, n);
+ int *pivots_gpu = THCudaIntTensor_data(state, rpivots_);
+
+ bool free_rinfo_ = !rinfo_;
+ if (rinfo_ == NULL) rinfo_ = THCudaIntTensor_new(state);
+ THCudaIntTensor_resize1d(state, rinfo_, num_batches);
+ int *info_gpu = THCudaIntTensor_data(state, rinfo_);
+
// Copy pointers to device.
real **d_result;
- long num_batches = ra__->size[0];
size_t matrices_size = num_batches * sizeof(real*);
THCudaCheck(THCudaMalloc(state, (void**)&d_result, matrices_size));
- THCudaIntTensor_resize1d(state, rpivots_, num_batches*n);
-
- int *pivots_gpu;
- THCudaCheck(THCudaMalloc(state, (void**)&pivots_gpu, sizeof(int)*num_batches*n));
-
const long block = 512;
const long grid = (num_batches + block - 1) / block;
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
(const real**)d_result, THCTensor_(data)(state, ra__),
ra__->stride[0], num_batches);
- THCudaIntTensor_resize1d(state, rinfo_, num_batches);
- int *info_gpu;
- THCudaCheck(THCudaMalloc(state, (void**)&info_gpu, sizeof(int)*num_batches));
-
#ifdef THC_REAL_IS_FLOAT
THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches);
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_Dgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches);
#endif
- if (!THCudaIntTensor_isContiguous(state, rinfo_)) {
- THError("Error: rinfo_ is not contiguous.");
- }
- THCudaCheck(cudaMemcpy(THCudaIntTensor_data(state, rinfo_), info_gpu, sizeof(int)*num_batches, cudaMemcpyDeviceToHost));
-
- if (!THCudaIntTensor_isContiguous(state, rpivots_)) {
- THError("Error: rpivots_ is not contiguous.");
- }
- THCudaCheck(cudaMemcpy(THCudaIntTensor_data(state, rpivots_), pivots_gpu, sizeof(int)*num_batches*n, cudaMemcpyDeviceToHost));
- THCudaIntTensor_resize2d(state, rpivots_, num_batches, n);
-
THCudaFree(state, d_result);
- THCudaFree(state, info_gpu);
- THCudaFree(state, pivots_gpu);
if (ra__ != ra_) {
THCTensor_(freeCopyTo)(state, ra__, ra_);
}
+ if (free_rinfo_) {
+ real min = THCudaIntTensor_minall(state, rinfo_);
+ real max = THCudaIntTensor_maxall(state, rinfo_);
+ THCudaIntTensor_free(state, rinfo_);
+ if (min != 0 || max != 0) {
+ THError("failed to factorize some batch elements (min info == %d, max info == %d)",
+ min, max);
+ }
+ }
+
#else
THError("unimplemented data type");
#endif