diff options
author | Brandon Amos <bamos@cs.cmu.edu> | 2017-06-13 20:45:40 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-22 19:45:01 +0300 |
commit | 860cdfdc508800ac45f4c8d5c3ab996a2e882e89 (patch) | |
tree | 83c9b49d85860caa3a498a0338a15f08e203a3ae | |
parent | 18d390c33942f58dd13d1a3455dc3224b209cd21 (diff) |
btrifact: Make pivoting optional.
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 22 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.h | 2 |
2 files changed, 20 insertions, 4 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index 0d47750..a6aa074 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -619,7 +619,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, #endif } -THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a) +THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a) { #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) THAssert(THCTensor_(checkGPU)(state, 2, ra_, a)); @@ -658,8 +658,20 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens long num_batches = ra__->size[0]; - THCudaIntTensor_resize2d(state, rpivots_, num_batches, n); - int *pivots_gpu = THCudaIntTensor_data(state, rpivots_); + if (!pivot) { + THCudaIntTensor *t = THCudaIntTensor_new(state); + THCudaIntTensor_range(state, t, 1, n, 1); + THCudaIntTensor_unsqueeze1d(state, t, t, 0); + THCudaIntTensor** ptrs = (THCudaIntTensor**) THAlloc(sizeof(THCudaIntTensor*)*num_batches); + for (long i=0; i<num_batches; i++) { + ptrs[i] = t; + } + THCudaIntTensor_catArray(state, rpivots_, ptrs, num_batches, 0); + THCudaIntTensor_free(state, t); + THFree(ptrs); + } else { + THCudaIntTensor_resize2d(state, rpivots_, num_batches, n); + } bool free_rinfo_ = !rinfo_; if (rinfo_ == NULL) rinfo_ = THCudaIntTensor_new(state); @@ -677,6 +689,10 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens (const real**)d_result, THCTensor_(data)(state, ra__), ra__->stride[0], num_batches); + int *pivots_gpu = NULL; + if (pivot) { + pivots_gpu = THCudaIntTensor_data(state, rpivots_); + } #ifdef THC_REAL_IS_FLOAT THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); #elif defined(THC_REAL_IS_DOUBLE) diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h index 1d9ddfa..1279d7e 100644 --- a/lib/THC/generic/THCTensorMathBlas.h +++ b/lib/THC/generic/THCTensorMathBlas.h @@ -9,7 +9,7 @@ THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTe THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2); THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2); -THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a); +THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a); THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots); |