diff options
Diffstat (limited to 'lib/THC/THCBlas.cu')
-rw-r--r-- | lib/THC/THCBlas.cu | 263 |
1 files changed, 180 insertions, 83 deletions
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu index 1edbcb0..5b99506 100644 --- a/lib/THC/THCBlas.cu +++ b/lib/THC/THCBlas.cu @@ -1,109 +1,79 @@ #include "THCBlas.h" #include "THCGeneral.h" +#include "THCHalf.h" -void THCudaBlas_swap(THCState *state, long n, float *x, long incx, float *y, long incy) +float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy) { - if(n == 1) - { - incx = 1; - incy = 1; - } - - if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; - THCublasCheck(cublasSswap(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy)); - return; - } - THError("Cublas_swap only supports n, incx and" - " incy upto signed integer limits: %d", INT_MAX); -} - -void THCudaBlas_scal(THCState *state, long n, float a, float *x, long incx) -{ - if(n == 1) - incx = 1; - - if( (n <= INT_MAX) && (incx <= INT_MAX) ) - { - int i_n = (int)n; - int i_incx = (int)incx; - THCublasCheck(cublasSscal(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx)); - return; - } - THError("Cublas_scal only supports n and incx " - "upto signed integer limits: %d", INT_MAX); -} - -void THCudaBlas_copy(THCState *state, long n, float *x, long incx, float *y, long incy) -{ - if(n == 1) - { + if (n == 1) { incx = 1; incy = 1; } - if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { + if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) { int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; - THCublasCheck(cublasScopy(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy)); - return; + float result; + THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result)); + return result; } - THError("Cublas_copy only supports n, incx and incy " - "upto signed integer limits: %d", INT_MAX); + THError("Cublas_Sdot only supports n, incx and incy " + "up to signed integer limits: %d", INT_MAX); + return 0; } -void THCudaBlas_axpy(THCState *state, long n, float a, float *x, long incx, float *y, long incy) +double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy) { - if(n == 1) - { + if (n == 1) { incx = 1; incy = 1; } - if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { + if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) { int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; - THCublasCheck(cublasSaxpy(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx, y, i_incy)); - return; + double result; + THCublasCheck(cublasDdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result)); + return result; } - THError("Cublas_axpy only supports n, incx and incy " - "upto signed integer limits: %d", INT_MAX); + THError("Cublas_Ddot only supports n, incx and incy " + "up to signed integer limits: %d", INT_MAX); + return 0; } -float THCudaBlas_dot(THCState *state, long n, float *x, long incx, float *y, long incy) +/* Level 2 */ +void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy) { if(n == 1) - { - incx = 1; - incy = 1; - } + lda = m; - if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + cublasOperation_t op; + if (trans == 't') op = CUBLAS_OP_T; + else if (trans == 'n') op = CUBLAS_OP_N; + else if (trans == 'c') op = CUBLAS_OP_C; + + if( (m <= INT_MAX) && (n <= INT_MAX) && + (lda > 0) && (lda <= INT_MAX) && + (incx > 0) && (incx <= INT_MAX) && + (incy > 0) && (incy <= INT_MAX) ) { + int i_m = (int)m; int i_n = (int)n; + int i_lda = (int)lda; int i_incx = (int)incx; int i_incy = (int)incy; - float result; - THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result)); - cudaDeviceSynchronize(); - return result; + + THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy)); + return; } - THError("Cublas_dot only supports n, incx and incy " - "upto signed integer limits: %d", INT_MAX); - return -1; + THError("Cublas_Sgemv only supports m, n, lda, incx, incy" + "in the range 0 < [val] <= %d", INT_MAX); } -/* Level 2 */ -void THCudaBlas_gemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy) +void THCudaBlas_Dgemv(THCState *state, char trans, long m, long n, double alpha, double *a, long lda, double *x, long incx, double beta, double *y, long incy) { if(n == 1) lda = m; @@ -124,14 +94,14 @@ void THCudaBlas_gemv(THCState *state, char trans, long m, long n, float alpha, f int i_incx = (int)incx; int i_incy = (int)incy; - THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy)); + THCublasCheck(cublasDgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy)); return; } - THError("Cublas_gemv only supports m, n, lda, incx, incy" + THError("Cublas_Dgemv only supports m, n, lda, incx, incy" "in the range 0 < [val] <= %d", INT_MAX); } -void THCudaBlas_ger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda) +void THCudaBlas_Sger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda) { if(n == 1) lda = m; @@ -147,10 +117,31 @@ void THCudaBlas_ger(THCState *state, long m, long n, float alpha, float *x, long THCublasCheck(cublasSger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); return; } - THError("Cublas_ger only supports m, n, lda, incx, incy" + THError("Cublas_Sger only supports m, n, lda, incx, incy" "with the bound [val] <= %d", INT_MAX); } +void THCudaBlas_Dger(THCState *state, long m, long n, double alpha, double *x, long incx, double *y, long incy, double *a, long lda) +{ + if(n == 1) + lda = m; + + if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_lda = (int)lda; + int i_incx = (int)incx; + int i_incy = (int)incy; + + THCublasCheck(cublasDger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); + return; + } + THError("Cublas_Dger only supports m, n, lda, incx, incy" + "with the bound [val] <= %d", INT_MAX); +} + + cublasOperation_t convertTransToCublasOperation(char trans) { if (trans == 't') return CUBLAS_OP_T; else if (trans == 'n') return CUBLAS_OP_N; @@ -193,7 +184,7 @@ void adjustLd(char transa, char transb, long m, long n, long k, long *lda, long } /* Level 3 */ -void THCudaBlas_gemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc) +void THCudaBlas_Sgemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc) { adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc); cublasOperation_t opa = convertTransToCublasOperation(transa); @@ -211,17 +202,84 @@ void THCudaBlas_gemm(THCState *state, char transa, char transb, long m, long n, THCublasCheck(cublasSgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc)); return; } - THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc" + THError("Cublas_Sgemm only supports m, n, k, lda, ldb, ldc" "with the bound [val] <= %d", INT_MAX); } -void THCudaBlas_gemmBatched(THCState *state, char transa, char transb, long m, long n, long k, - float alpha, const float *a[], long lda, const float *b[], long ldb, - float beta, float *c[], long ldc, long batchCount) +#ifdef CUDA_HALF_TENSOR +// In CUDA 8.0, definition of data types for sgemmex changed +#if CUDA_VERSION < 8000 +# define CUDA_R_16F CUBLAS_DATA_HALF +#endif + +void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, long k, half alpha, half *a, long lda, half *b, long ldb, half beta, half *c, long ldc) +{ + adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc); + cublasOperation_t opa = convertTransToCublasOperation(transa); + cublasOperation_t opb = convertTransToCublasOperation(transb); + + if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; + + // Check for native Hgemm support + if (THC_nativeHalfInstructions(state)) { + THCublasCheck(cublasHgemm(THCState_getCurrentBlasHandle(state), opa, opb, + i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, + &beta, c, i_ldc)); + } else { + // Simulated Hgemm + float fAlpha = THC_half2float(alpha); + float fBeta = THC_half2float(beta); + + THCublasCheck(cublasSgemmEx(THCState_getCurrentBlasHandle(state), opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); + } + + return; + } + THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc" + "with th bound [val] <= %d", INT_MAX); +} +#endif + +void THCudaBlas_Dgemm(THCState *state, char transa, char transb, long m, long n, long k, double alpha, double *a, long lda, double *b, long ldb, double beta, double *c, long ldc) +{ + adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc); + cublasOperation_t opa = convertTransToCublasOperation(transa); + cublasOperation_t opb = convertTransToCublasOperation(transb); + + if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; + + THCublasCheck(cublasDgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc)); + return; + } + THError("Cublas_Dgemm only supports m, n, k, lda, ldb, ldc" + "with the bound [val] <= %d", INT_MAX); +} + + +void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m, long n, long k, + float alpha, const float *a[], long lda, const float *b[], long ldb, + float beta, float *c[], long ldc, long batchCount) { if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) { - THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc, batchCount" + THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" "with the bound [val] <= %d", INT_MAX); } @@ -235,22 +293,61 @@ void THCudaBlas_gemmBatched(THCState *state, char transa, char transb, long m, l (int)batchCount)); } +void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k, + double alpha, const double *a[], long lda, const double *b[], long ldb, + double beta, double *c[], long ldc, long batchCount) +{ + if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) + { + THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" + "with the bound [val] <= %d", INT_MAX); + } + + adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc); + cublasOperation_t opa = convertTransToCublasOperation(transa); + cublasOperation_t opb = convertTransToCublasOperation(transb); + + THCublasCheck(cublasDgemmBatched(THCState_getCurrentBlasHandle(state), + opa, opb, (int)m, (int)n, (int)k, + &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, + (int)batchCount)); +} + /* Inverse */ -void THCudaBlas_getrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) { +void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) { if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) ) { - THError("Cublas_getrf only supports n, lda, batchSize" + THError("Cublas_Sgetrf only supports n, lda, batchSize" "with the bound [val] <= %d", INT_MAX); } THCublasCheck(cublasSgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize)); } -void THCudaBlas_getri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) { +void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) { + if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) ) + { + THError("Cublas_Dgetrf only supports n, lda, batchSize" + "with the bound [val] <= %d", INT_MAX); + } + THCublasCheck(cublasDgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize)); +} + +void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) { if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) ) { - THError("Cublas_getrf only supports n, lda, ldc, batchSize" + THError("Cublas_Sgetri only supports n, lda, ldc, batchSize" "with the bound [val] <= %d", INT_MAX); } THCublasCheck(cublasSgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize)); } + +void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) { + + if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) ) + { + THError("Cublas_Dgetri only supports n, lda, ldc, batchSize" + "with the bound [val] <= %d", INT_MAX); + } + THCublasCheck(cublasDgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize)); +} |