diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:08:27 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:31:32 +0300 |
commit | 7250cc589f279bf0b3dd61563c7ce8087e7e63c6 (patch) | |
tree | 182801d89c0b089a2010f6341fa26da383891be2 | |
parent | 5c719c51da75e72430bd3ff5b45ce2d062c7c349 (diff) |
[cutorch mag2gen] move potr* to generic
-rw-r--r-- | TensorMath.lua | 32 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 3 | ||||
-rw-r--r-- | lib/THC/THCTensorMathMagma.cu | 117 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 131 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.h | 3 | ||||
-rw-r--r-- | test/test.lua | 46 |
6 files changed, 194 insertions, 138 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 0b8bad8..0db19a7 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1237,6 +1237,38 @@ for k, Tensor_ in pairs(handledTypenames) do {{name=Tensor, default=true, returned=true, invisible=true}, {name=Tensor}}) + wrap("potri", + cname("potri"), + {{name=Tensor, returned=true}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}, + cname("potri"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}) + + wrap("potrf", + cname("potrf"), + {{name=Tensor, returned=true}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}, + cname("potrf"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}) + + wrap("potrs", + cname("potrs"), + {{name=Tensor, returned=true}, + {name=Tensor}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}, + cname("potrs"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name=Tensor}, + {name='charoption', values={'U', 'L'}, default='U'}}) + end wrap("dot", diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 7e02c6b..0850e3c 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -45,9 +45,6 @@ #include "THCGenerateAllTypes.h" // MAGMA (i.e. CUDA implementation of LAPACK functions) -THC_API void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo); -THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo); -THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b, const char *uplo); THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a); diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu index a60b792..7edcae9 100644 --- a/lib/THC/THCTensorMathMagma.cu +++ b/lib/THC/THCTensorMathMagma.cu @@ -23,123 +23,6 @@ void THCMagma_init(THCState *state) #endif } -__global__ void THCudaTensor_copyUpperSymmetric(float *input, int n, int len) -{ - for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) { - const int r = idx % n; - const int c = idx / n; - if (r > c) { - input[idx] = input[r*n + c]; - } - } -} - -__global__ void THCudaTensor_copyLowerSymmetric(float *input, int n, int len) -{ - for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) { - const int r = idx % n; - const int c = idx / n; - if (r < c) { - input[idx] = input[r*n + c]; - } - } -} - -void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo) -{ -#ifdef USE_MAGMA - THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); - THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); - - int n = a->size[0]; - magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; - - THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a); - float *input_data = THCudaTensor_data(state, input); - - int info; - magma_spotri_gpu(ul, n, input_data, n, &info); - if (info > 0) - THError("MAGMA potri : A(%d,%d) is 0, A cannot be factorized", info, info); - else if (info < 0) - THError("MAGMA potri : Argument %d : illegal value", -info); - - cudaStream_t stream = THCState_getCurrentStream(state); - const int len = n*n; - dim3 blocks(std::min(DIVUP(len, 128), 65535)); - dim3 threads(128); - if (uplo[0] == 'U') { - THCudaTensor_copyUpperSymmetric<<<blocks, threads, 0, stream>>>(input_data, n, len); - } else { - THCudaTensor_copyLowerSymmetric<<<blocks, threads, 0, stream>>>(input_data, n, len); - } - - THCudaTensor_freeCopyTo(state, input, ra_); -#else - THError(NoMagma(potri)); -#endif -} - -void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo) -{ -#ifdef USE_MAGMA - THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); - THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); - - int n = a->size[0]; - magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; - - THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a); - float *input_data = THCudaTensor_data(state, input); - - int info; - magma_spotrf_gpu(ul, n, input_data, n, &info); - - // check error value - if (info > 0) - THError("MAGMA potrf : A(%d,%d) is 0, A cannot be factorized", info, info); - else if (info < 0) - THError("MAGMA potrf : Argument %d : illegal value", -info); - - if (uplo[0] == 'U') { - THCudaTensor_triu(state, ra_, input, 0); - } else { - THCudaTensor_tril(state, ra_, input, 0); - } - THCudaTensor_free(state, input); -#else - THError(NoMagma(potrf)); -#endif -} - -void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *b, THCudaTensor *a, const char *uplo) -{ -#ifdef USE_MAGMA - THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); - - int n = a->size[0]; - int nrhs = b->size[1]; - magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; - - THCudaTensor *b_ = THCudaTensor_newColumnMajor(state, rb_, b); - float *b_data = THCudaTensor_data(state, b_); - THCudaTensor *a_ = THCudaTensor_newColumnMajor(state, a, a); - float *a_data = THCudaTensor_data(state, a_); - - int info; - magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info); - - // check error value - if (info < 0) - THError("MAGMA potrs : Argument %d : illegal value", -info); - - THCudaTensor_freeCopyTo(state, b_, rb_); - THCudaTensor_free(state, a_); -#else - THError(NoMagma(potrs)); -#endif -} - void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a_) { #ifdef USE_MAGMA diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu index 36e7e0d..41e7569 100644 --- a/lib/THC/generic/THCTensorMathMagma.cu +++ b/lib/THC/generic/THCTensorMathMagma.cu @@ -401,6 +401,137 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a) #endif } +__global__ void THCTensor_(copyUpperSymmetric)(real *input, int n, int len) +{ + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) { + const int r = idx % n; + const int c = idx / n; + if (r > c) { + input[idx] = input[r*n + c]; + } + } +} + +__global__ void THCTensor_(copyLowerSymmetric)(real *input, int n, int len) +{ + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) { + const int r = idx % n; + const int c = idx / n; + if (r < c) { + input[idx] = input[r*n + c]; + } + } +} + +void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) +{ +#ifdef USE_MAGMA + THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); + THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); + + int n = a->size[0]; + magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; + + THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a); + real *input_data = THCTensor_(data)(state, input); + + int info; +#if defined(THC_REAL_IS_FLOAT) + magma_spotri_gpu(ul, n, input_data, n, &info); +#else + magma_dpotri_gpu(ul, n, input_data, n, &info); +#endif + + if (info > 0) + THError("MAGMA potri : A(%d,%d) is 0, A cannot be factorized", info, info); + else if (info < 0) + THError("MAGMA potri : Argument %d : illegal value", -info); + + cudaStream_t stream = THCState_getCurrentStream(state); + const int len = n*n; + dim3 blocks(std::min(DIVUP(len, 128), 65535)); + dim3 threads(128); + if (uplo[0] == 'U') { + THCTensor_(copyUpperSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len); + } else { + THCTensor_(copyLowerSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len); + } + + THCTensor_(freeCopyTo)(state, input, ra_); +#else + THError(NoMagma(potri)); +#endif +} + +void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) +{ +#ifdef USE_MAGMA + THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); + THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); + + int n = a->size[0]; + magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; + + THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a); + real *input_data = THCTensor_(data)(state, input); + + int info; +#if defined(THC_REAL_IS_FLOAT) + magma_spotrf_gpu(ul, n, input_data, n, &info); +#else + magma_dpotrf_gpu(ul, n, input_data, n, &info); +#endif + + // check error value + if (info > 0) + THError("MAGMA potrf : A(%d,%d) is 0, A cannot be factorized", info, info); + else if (info < 0) + THError("MAGMA potrf : Argument %d : illegal value", -info); + + if (uplo[0] == 'U') { + THCTensor_(triu)(state, ra_, input, 0); + } else { + THCTensor_(tril)(state, ra_, input, 0); + } + THCTensor_(free)(state, input); +#else + THError(NoMagma(potrf)); +#endif +} + +void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo) +{ +#ifdef USE_MAGMA + THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); + + int n = a->size[0]; + int nrhs = b->size[1]; + magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; + + THCTensor *b_ = THCTensor_(newColumnMajor)(state, rb_, b); + real *b_data = THCTensor_(data)(state, b_); + THCTensor *a_ = THCTensor_(newColumnMajor)(state, a, a); + real *a_data = THCTensor_(data)(state, a_); + + int info; +#if defined(THC_REAL_IS_FLOAT) + magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info); +#else + magma_dpotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info); +#endif + + // check error value + if (info < 0) + THError("MAGMA potrs : Argument %d : illegal value", -info); + + THCTensor_(freeCopyTo)(state, b_, rb_); + THCTensor_(free)(state, a_); +#else + THError(NoMagma(potrs)); +#endif +} + + #endif #endif diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h index 4c8dbe3..ce0ed29 100644 --- a/lib/THC/generic/THCTensorMathMagma.h +++ b/lib/THC/generic/THCTensorMathMagma.h @@ -66,6 +66,9 @@ THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, T THC_API void THCTensor_(gesvd)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *a, const char *jobu); THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *ra_, THCTensor *a, const char *jobu); THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a); +THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo); +THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo); +THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo); #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/test/test.lua b/test/test.lua index e766de3..c49e17a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2328,7 +2328,7 @@ function test.inverse() for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do local at = a:type(typename) local i1 = torch.inverse(at) - local i2 = torch.inverse(at:cuda()) + local i2 = torch.inverse(a:cuda()) tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong inverse answer") end end @@ -2439,14 +2439,17 @@ if cutorch.magma then } A = A * A:t() - for _, triarg in ipairs({'U', 'L'}) do - local chol = torch.potrf(A, triarg) - - local i1 = torch.potri(chol, triarg) - local i2 = torch.potri(chol:cuda(), triarg) - local M = A:cuda() * i2 - tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potri answer") - tester:assertle((M - torch.eye(A:size(1)):cuda()):abs():max(), 1e-5, "potri not an inverse") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = A:type(typename) + for _, triarg in ipairs({'U', 'L'}) do + local chol = torch.potrf(at, triarg) + + local i1 = torch.potri(chol, triarg) + local i2 = torch.potri(chol:cuda(), triarg) + local M = at:cuda() * i2 + tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potri answer") + tester:assertle((M - torch.eye(at:size(1)):cuda()):abs():max(), 1e-5, "potri not an inverse") + end end end @@ -2458,10 +2461,13 @@ if cutorch.magma then {-0.6738, 0.4734,-1.1123, 2.4071,-1.2756}, {-3.3883, 0.2807, 0.8161,-1.2756, 4.3415}, } - for _, triarg in ipairs({'U', 'L'}) do - local i1 = torch.potrf(A, triarg) - local i2 = torch.potrf(A:cuda(), triarg) - tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potrf answer") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = A:type(typename) + for _, triarg in ipairs({'U', 'L'}) do + local i1 = torch.potrf(at, triarg) + local i2 = torch.potrf(at:cuda(), triarg) + tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potrf answer") + end end end @@ -2478,11 +2484,15 @@ if cutorch.magma then {0.2334, 0.8594, 0.4103}, {0.7556, 0.1966, 0.9637}, {0.1420, 0.7185, 0.7476}}) - for _, triarg in ipairs({'U', 'L'}) do - local chol = torch.potrf(A, triarg) - local solve1 = torch.potrs(B, chol, triarg) - local solve2 = torch.potrs(B:cuda(), chol:cuda(), triarg) - tester:assertle((solve2 - solve1:cuda()):abs():max(), 1e-4, "wrong potrs answer") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = A:type(typename) + local bt = B:type(typename) + for _, triarg in ipairs({'U', 'L'}) do + local chol = torch.potrf(at, triarg) + local solve1 = torch.potrs(bt, chol, triarg) + local solve2 = torch.potrs(bt:cuda(), chol:cuda(), triarg) + tester:assertle((solve2 - solve1:cuda()):abs():max(), 1e-4, "wrong potrs answer") + end end end |