Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-15 23:58:05 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-16 00:31:32 +0300
commit5c719c51da75e72430bd3ff5b45ce2d062c7c349 (patch)
tree5caa43c5719349ab14531fc33315a255224e1d6e
parent26398da39ef4a423c922d9b60b1ac9c24974e3ed (diff)
[cutorch mag2gen] move inverse to generic
-rw-r--r--TensorMath.lua8
-rw-r--r--lib/THC/THCTensorMath.h1
-rw-r--r--lib/THC/THCTensorMathMagma.cu99
-rw-r--r--lib/THC/generic/THCTensorMathMagma.cu117
-rw-r--r--lib/THC/generic/THCTensorMathMagma.h1
-rw-r--r--test/test.lua9
6 files changed, 132 insertions, 103 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index aa0b626..0b8bad8 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1229,6 +1229,14 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor},
{name='charoption', values={'A', 'S'}, default='S'}})
+ wrap("inverse",
+ cname("getri"),
+ {{name=Tensor, returned=true},
+ {name=Tensor}},
+ cname("getri"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor}})
+
end
wrap("dot",
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 748b72f..7e02c6b 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -45,7 +45,6 @@
#include "THCGenerateAllTypes.h"
// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
THC_API void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo);
THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo);
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b, const char *uplo);
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index f740836..a60b792 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,105 +23,6 @@ void THCMagma_init(THCState *state)
#endif
}
-void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a)
-{
-#ifdef USE_MAGMA
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
-
- int info;
- int n = a->size[0];
- int lwork = n * magma_get_sgetri_nb(n);
-
- THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a);
- float *input_data = THCudaTensor_data(state, input);
-
- int *ipiv = th_magma_malloc_pinned<int>(n);
-
- THCudaTensor *work = THCudaTensor_newWithSize1d(state, lwork);
- float *work_data = THCudaTensor_data(state, work);
-
- // Run LU
- magma_sgetrf_gpu(n, n, input_data, n, ipiv, &info);
- if (info > 0)
- THError("MAGMA getrf : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("MAGMA getrf : Argument %d : illegal value", -info);
-
- // Inverse
- magma_sgetri_gpu(n, input_data, n, ipiv, work_data, lwork, &info);
- if (info > 0)
- THError("MAGMA getri : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("MAGMA getri : Argument %d : illegal value", -info);
-
- THCudaTensor_free(state, work);
- magma_free_pinned(ipiv);
- THCudaTensor_freeCopyTo(state, input, ra_);
-#else
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
-
- int n = a->size[0];
-
- // input
- THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a);
- // output
- THCudaTensor *output = THCudaTensor_newColumnMajor(state, ra_, a);
-
- size_t matrices_size = sizeof(float*);
-
- float **matrices1 = (float **)THAlloc(matrices_size);
- const float **matrices1_const = (const float **)THAlloc(matrices_size);
- float **matrices2 = (float **)THAlloc(matrices_size);
- matrices1[0] = THCudaTensor_data(state, input);
- matrices1_const[0] = THCudaTensor_data(state, input);
- matrices2[0] = THCudaTensor_data(state, output);
-
- // Copy pointers to device.
- float **d_matrices1, **d_matrices2;
- const float **d_matrices1_const;
- THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
- THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1_const, matrices_size));
- THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
-
- THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- THCudaCheck(cudaMemcpyAsync(d_matrices1_const, matrices1_const, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- int info;
- int *info_gpu;
- THCudaCheck(THCudaMalloc(state, (void**)&info_gpu, sizeof(int)));
-
- int *ipiv_gpu;
- THCudaCheck(THCudaMalloc(state, (void**)&ipiv_gpu, n * sizeof(int)));
-
- // Run LU
- THCudaBlas_Sgetrf(state, n, d_matrices1, n, ipiv_gpu, info_gpu, 1);
-
- THCudaCheck(cudaMemcpy(&info, info_gpu, sizeof(int), cudaMemcpyDeviceToHost));
-
- if (info > 0)
- THError("CUBLAS getrf : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("CUBLAS getrf : Argument %d : illegal value", -info);
-
- // Inverse
- THCudaBlas_Sgetri(state, n, d_matrices1_const, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
- if (info > 0)
- THError("CUBLAS getri : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("CUBLAS getri : Argument %d : illegal value", -info);
-
- THCudaCheck(THCudaFree(state, ipiv_gpu));
- THCudaCheck(THCudaFree(state, info_gpu));
- THCudaTensor_freeCopyTo(state, output, input);
-#endif
-
-}
-
__global__ void THCudaTensor_copyUpperSymmetric(float *input, int n, int len)
{
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index f469652..36e7e0d 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -284,6 +284,123 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_,
#endif
}
+THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+
+ int info;
+ int n = a->size[0];
+ int lwork = n * magma_get_sgetri_nb(n);
+
+ THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
+ real *input_data = THCTensor_(data)(state, input);
+
+ int *ipiv = th_magma_malloc_pinned<int>(n);
+
+ THCTensor *work = THCTensor_(newWithSize1d)(state, lwork);
+ real *work_data = THCTensor_(data)(state, work);
+
+ // Run LU
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgetrf_gpu(n, n, input_data, n, ipiv, &info);
+#else
+ magma_dgetrf_gpu(n, n, input_data, n, ipiv, &info);
+#endif
+
+ if (info > 0)
+ THError("MAGMA getrf : U(%d,%d) is 0, U is singular", info, info);
+ else if (info < 0)
+ THError("MAGMA getrf : Argument %d : illegal value", -info);
+
+ // Inverse
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgetri_gpu(n, input_data, n, ipiv, work_data, lwork, &info);
+#else
+ magma_dgetri_gpu(n, input_data, n, ipiv, work_data, lwork, &info);
+#endif
+
+ if (info > 0)
+ THError("MAGMA getri : U(%d,%d) is 0, U is singular", info, info);
+ else if (info < 0)
+ THError("MAGMA getri : Argument %d : illegal value", -info);
+
+ THCTensor_(free)(state, work);
+ magma_free_pinned(ipiv);
+ THCTensor_(freeCopyTo)(state, input, ra_);
+#else
+ THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+
+ int n = a->size[0];
+
+ // input
+ THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
+ // output
+ THCTensor *output = THCTensor_(newColumnMajor)(state, ra_, a);
+
+ size_t matrices_size = sizeof(real*);
+
+ real **matrices1 = (real **)THAlloc(matrices_size);
+ const real **matrices1_const = (const real **)THAlloc(matrices_size);
+ real **matrices2 = (real **)THAlloc(matrices_size);
+ matrices1[0] = THCTensor_(data)(state, input);
+ matrices1_const[0] = THCTensor_(data)(state, input);
+ matrices2[0] = THCTensor_(data)(state, output);
+
+ // Copy pointers to device.
+ real **d_matrices1, **d_matrices2;
+ const real **d_matrices1_const;
+ THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
+ THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1_const, matrices_size));
+ THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
+
+ THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ THCudaCheck(cudaMemcpyAsync(d_matrices1_const, matrices1_const, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size,
+ cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ int info;
+ int *info_gpu;
+ THCudaCheck(THCudaMalloc(state, (void**)&info_gpu, sizeof(int)));
+
+ int *ipiv_gpu;
+ THCudaCheck(THCudaMalloc(state, (void**)&ipiv_gpu, n * sizeof(int)));
+
+ // Run LU
+#if defined(THC_REAL_IS_FLOAT)
+ THCudaBlas_Sgetrf(state, n, d_matrices1, n, ipiv_gpu, info_gpu, 1);
+#else
+ THCudaBlas_Dgetrf(state, n, d_matrices1, n, ipiv_gpu, info_gpu, 1);
+#endif
+
+ THCudaCheck(cudaMemcpy(&info, info_gpu, sizeof(int), cudaMemcpyDeviceToHost));
+
+ if (info > 0)
+ THError("CUBLAS getrf : U(%d,%d) is 0, U is singular", info, info);
+ else if (info < 0)
+ THError("CUBLAS getrf : Argument %d : illegal value", -info);
+
+ // Inverse
+#if defined(THC_REAL_IS_FLOAT)
+ THCudaBlas_Sgetri(state, n, d_matrices1_const, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
+#else
+ THCudaBlas_Dgetri(state, n, d_matrices1_const, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
+#endif
+
+ if (info > 0)
+ THError("CUBLAS getri : U(%d,%d) is 0, U is singular", info, info);
+ else if (info < 0)
+ THError("CUBLAS getri : Argument %d : illegal value", -info);
+
+ THCudaCheck(THCudaFree(state, ipiv_gpu));
+ THCudaCheck(THCudaFree(state, info_gpu));
+ THCTensor_(freeCopyTo)(state, output, input);
+#endif
+}
+
#endif
#endif
diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h
index 93a349a..4c8dbe3 100644
--- a/lib/THC/generic/THCTensorMathMagma.h
+++ b/lib/THC/generic/THCTensorMathMagma.h
@@ -65,6 +65,7 @@ THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, T
THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);
THC_API void THCTensor_(gesvd)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *a, const char *jobu);
THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *ra_, THCTensor *a, const char *jobu);
+THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/test/test.lua b/test/test.lua
index 5f28023..e766de3 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2325,9 +2325,12 @@ end
function test.inverse()
local a = torch.eye(5):add(torch.Tensor(5, 5):uniform(-0.1, 0.1))
- local i1 = torch.inverse(a)
- local i2 = torch.inverse(a:cuda())
- tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong inverse answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = a:type(typename)
+ local i1 = torch.inverse(at)
+ local i2 = torch.inverse(at:cuda())
+ tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong inverse answer")
+ end
end
if cutorch.magma then