From 860cdfdc508800ac45f4c8d5c3ab996a2e882e89 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Tue, 13 Jun 2017 13:45:40 -0400 Subject: btrifact: Make pivoting optional. --- lib/THC/generic/THCTensorMathBlas.cu | 22 +++++++++++++++++++--- lib/THC/generic/THCTensorMathBlas.h | 2 +- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index 0d47750..a6aa074 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -619,7 +619,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, #endif } -THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a) +THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a) { #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) THAssert(THCTensor_(checkGPU)(state, 2, ra_, a)); @@ -658,8 +658,20 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens long num_batches = ra__->size[0]; - THCudaIntTensor_resize2d(state, rpivots_, num_batches, n); - int *pivots_gpu = THCudaIntTensor_data(state, rpivots_); + if (!pivot) { + THCudaIntTensor *t = THCudaIntTensor_new(state); + THCudaIntTensor_range(state, t, 1, n, 1); + THCudaIntTensor_unsqueeze1d(state, t, t, 0); + THCudaIntTensor** ptrs = (THCudaIntTensor**) THAlloc(sizeof(THCudaIntTensor*)*num_batches); + for (long i=0; istride[0], num_batches); + int *pivots_gpu = NULL; + if (pivot) { + pivots_gpu = THCudaIntTensor_data(state, rpivots_); + } #ifdef THC_REAL_IS_FLOAT THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); #elif defined(THC_REAL_IS_DOUBLE) diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h index 1d9ddfa..1279d7e 100644 --- a/lib/THC/generic/THCTensorMathBlas.h +++ b/lib/THC/generic/THCTensorMathBlas.h @@ -9,7 +9,7 @@ THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTe THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2); THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2); -THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a); +THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a); THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots); -- cgit v1.2.3