Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrandon Amos <bamos@cs.cmu.edu>2017-06-13 20:45:40 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-22 19:45:01 +0300
commit860cdfdc508800ac45f4c8d5c3ab996a2e882e89 (patch)
tree83c9b49d85860caa3a498a0338a15f08e203a3ae
parent18d390c33942f58dd13d1a3455dc3224b209cd21 (diff)
btrifact: Make pivoting optional.
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu22
-rw-r--r--lib/THC/generic/THCTensorMathBlas.h2
2 files changed, 20 insertions, 4 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index 0d47750..a6aa074 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -619,7 +619,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
#endif
}
-THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a)
+THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a)
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THAssert(THCTensor_(checkGPU)(state, 2, ra_, a));
@@ -658,8 +658,20 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
long num_batches = ra__->size[0];
- THCudaIntTensor_resize2d(state, rpivots_, num_batches, n);
- int *pivots_gpu = THCudaIntTensor_data(state, rpivots_);
+ if (!pivot) {
+ THCudaIntTensor *t = THCudaIntTensor_new(state);
+ THCudaIntTensor_range(state, t, 1, n, 1);
+ THCudaIntTensor_unsqueeze1d(state, t, t, 0);
+ THCudaIntTensor** ptrs = (THCudaIntTensor**) THAlloc(sizeof(THCudaIntTensor*)*num_batches);
+ for (long i=0; i<num_batches; i++) {
+ ptrs[i] = t;
+ }
+ THCudaIntTensor_catArray(state, rpivots_, ptrs, num_batches, 0);
+ THCudaIntTensor_free(state, t);
+ THFree(ptrs);
+ } else {
+ THCudaIntTensor_resize2d(state, rpivots_, num_batches, n);
+ }
bool free_rinfo_ = !rinfo_;
if (rinfo_ == NULL) rinfo_ = THCudaIntTensor_new(state);
@@ -677,6 +689,10 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
(const real**)d_result, THCTensor_(data)(state, ra__),
ra__->stride[0], num_batches);
+ int *pivots_gpu = NULL;
+ if (pivot) {
+ pivots_gpu = THCudaIntTensor_data(state, rpivots_);
+ }
#ifdef THC_REAL_IS_FLOAT
THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches);
#elif defined(THC_REAL_IS_DOUBLE)
diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h
index 1d9ddfa..1279d7e 100644
--- a/lib/THC/generic/THCTensorMathBlas.h
+++ b/lib/THC/generic/THCTensorMathBlas.h
@@ -9,7 +9,7 @@ THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTe
THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2);
THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2);
-THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, THCTensor *a);
+THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a);
THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots);