Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/THCBlas.cu')
-rw-r--r--lib/THC/THCBlas.cu263
1 files changed, 180 insertions, 83 deletions
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu
index 1edbcb0..5b99506 100644
--- a/lib/THC/THCBlas.cu
+++ b/lib/THC/THCBlas.cu
@@ -1,109 +1,79 @@
#include "THCBlas.h"
#include "THCGeneral.h"
+#include "THCHalf.h"
-void THCudaBlas_swap(THCState *state, long n, float *x, long incx, float *y, long incy)
+float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy)
{
- if(n == 1)
- {
- incx = 1;
- incy = 1;
- }
-
- if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
- {
- int i_n = (int)n;
- int i_incx = (int)incx;
- int i_incy = (int)incy;
- THCublasCheck(cublasSswap(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy));
- return;
- }
- THError("Cublas_swap only supports n, incx and"
- " incy upto signed integer limits: %d", INT_MAX);
-}
-
-void THCudaBlas_scal(THCState *state, long n, float a, float *x, long incx)
-{
- if(n == 1)
- incx = 1;
-
- if( (n <= INT_MAX) && (incx <= INT_MAX) )
- {
- int i_n = (int)n;
- int i_incx = (int)incx;
- THCublasCheck(cublasSscal(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx));
- return;
- }
- THError("Cublas_scal only supports n and incx "
- "upto signed integer limits: %d", INT_MAX);
-}
-
-void THCudaBlas_copy(THCState *state, long n, float *x, long incx, float *y, long incy)
-{
- if(n == 1)
- {
+ if (n == 1) {
incx = 1;
incy = 1;
}
- if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
- {
+ if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
- THCublasCheck(cublasScopy(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy));
- return;
+ float result;
+ THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
+ return result;
}
- THError("Cublas_copy only supports n, incx and incy "
- "upto signed integer limits: %d", INT_MAX);
+ THError("Cublas_Sdot only supports n, incx and incy "
+ "up to signed integer limits: %d", INT_MAX);
+ return 0;
}
-void THCudaBlas_axpy(THCState *state, long n, float a, float *x, long incx, float *y, long incy)
+double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy)
{
- if(n == 1)
- {
+ if (n == 1) {
incx = 1;
incy = 1;
}
- if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
- {
+ if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
- THCublasCheck(cublasSaxpy(THCState_getCurrentBlasHandle(state), i_n, &a, x, i_incx, y, i_incy));
- return;
+ double result;
+ THCublasCheck(cublasDdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
+ return result;
}
- THError("Cublas_axpy only supports n, incx and incy "
- "upto signed integer limits: %d", INT_MAX);
+ THError("Cublas_Ddot only supports n, incx and incy "
+ "up to signed integer limits: %d", INT_MAX);
+ return 0;
}
-float THCudaBlas_dot(THCState *state, long n, float *x, long incx, float *y, long incy)
+/* Level 2 */
+void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy)
{
if(n == 1)
- {
- incx = 1;
- incy = 1;
- }
+ lda = m;
- if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
+ cublasOperation_t op;
+ if (trans == 't') op = CUBLAS_OP_T;
+ else if (trans == 'n') op = CUBLAS_OP_N;
+ else if (trans == 'c') op = CUBLAS_OP_C;
+
+ if( (m <= INT_MAX) && (n <= INT_MAX) &&
+ (lda > 0) && (lda <= INT_MAX) &&
+ (incx > 0) && (incx <= INT_MAX) &&
+ (incy > 0) && (incy <= INT_MAX) )
{
+ int i_m = (int)m;
int i_n = (int)n;
+ int i_lda = (int)lda;
int i_incx = (int)incx;
int i_incy = (int)incy;
- float result;
- THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
- cudaDeviceSynchronize();
- return result;
+
+ THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
+ return;
}
- THError("Cublas_dot only supports n, incx and incy "
- "upto signed integer limits: %d", INT_MAX);
- return -1;
+ THError("Cublas_Sgemv only supports m, n, lda, incx, incy"
+ "in the range 0 < [val] <= %d", INT_MAX);
}
-/* Level 2 */
-void THCudaBlas_gemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy)
+void THCudaBlas_Dgemv(THCState *state, char trans, long m, long n, double alpha, double *a, long lda, double *x, long incx, double beta, double *y, long incy)
{
if(n == 1)
lda = m;
@@ -124,14 +94,14 @@ void THCudaBlas_gemv(THCState *state, char trans, long m, long n, float alpha, f
int i_incx = (int)incx;
int i_incy = (int)incy;
- THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
+ THCublasCheck(cublasDgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
return;
}
- THError("Cublas_gemv only supports m, n, lda, incx, incy"
+ THError("Cublas_Dgemv only supports m, n, lda, incx, incy"
"in the range 0 < [val] <= %d", INT_MAX);
}
-void THCudaBlas_ger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda)
+void THCudaBlas_Sger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda)
{
if(n == 1)
lda = m;
@@ -147,10 +117,31 @@ void THCudaBlas_ger(THCState *state, long m, long n, float alpha, float *x, long
THCublasCheck(cublasSger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
return;
}
- THError("Cublas_ger only supports m, n, lda, incx, incy"
+ THError("Cublas_Sger only supports m, n, lda, incx, incy"
"with the bound [val] <= %d", INT_MAX);
}
+void THCudaBlas_Dger(THCState *state, long m, long n, double alpha, double *x, long incx, double *y, long incy, double *a, long lda)
+{
+ if(n == 1)
+ lda = m;
+
+ if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
+ {
+ int i_m = (int)m;
+ int i_n = (int)n;
+ int i_lda = (int)lda;
+ int i_incx = (int)incx;
+ int i_incy = (int)incy;
+
+ THCublasCheck(cublasDger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
+ return;
+ }
+ THError("Cublas_Dger only supports m, n, lda, incx, incy"
+ "with the bound [val] <= %d", INT_MAX);
+}
+
+
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') return CUBLAS_OP_T;
else if (trans == 'n') return CUBLAS_OP_N;
@@ -193,7 +184,7 @@ void adjustLd(char transa, char transb, long m, long n, long k, long *lda, long
}
/* Level 3 */
-void THCudaBlas_gemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc)
+void THCudaBlas_Sgemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc)
{
adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
cublasOperation_t opa = convertTransToCublasOperation(transa);
@@ -211,17 +202,84 @@ void THCudaBlas_gemm(THCState *state, char transa, char transb, long m, long n,
THCublasCheck(cublasSgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
return;
}
- THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc"
+ THError("Cublas_Sgemm only supports m, n, k, lda, ldb, ldc"
"with the bound [val] <= %d", INT_MAX);
}
-void THCudaBlas_gemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
- float alpha, const float *a[], long lda, const float *b[], long ldb,
- float beta, float *c[], long ldc, long batchCount)
+#ifdef CUDA_HALF_TENSOR
+// In CUDA 8.0, definition of data types for sgemmex changed
+#if CUDA_VERSION < 8000
+# define CUDA_R_16F CUBLAS_DATA_HALF
+#endif
+
+void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, long k, half alpha, half *a, long lda, half *b, long ldb, half beta, half *c, long ldc)
+{
+ adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
+ cublasOperation_t opa = convertTransToCublasOperation(transa);
+ cublasOperation_t opb = convertTransToCublasOperation(transb);
+
+ if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
+ {
+ int i_m = (int)m;
+ int i_n = (int)n;
+ int i_k = (int)k;
+ int i_lda = (int)lda;
+ int i_ldb = (int)ldb;
+ int i_ldc = (int)ldc;
+
+ // Check for native Hgemm support
+ if (THC_nativeHalfInstructions(state)) {
+ THCublasCheck(cublasHgemm(THCState_getCurrentBlasHandle(state), opa, opb,
+ i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
+ &beta, c, i_ldc));
+ } else {
+ // Simulated Hgemm
+ float fAlpha = THC_half2float(alpha);
+ float fBeta = THC_half2float(beta);
+
+ THCublasCheck(cublasSgemmEx(THCState_getCurrentBlasHandle(state), opa, opb,
+ i_m, i_n, i_k, &fAlpha,
+ a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+ i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+ }
+
+ return;
+ }
+ THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc"
+ "with th bound [val] <= %d", INT_MAX);
+}
+#endif
+
+void THCudaBlas_Dgemm(THCState *state, char transa, char transb, long m, long n, long k, double alpha, double *a, long lda, double *b, long ldb, double beta, double *c, long ldc)
+{
+ adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
+ cublasOperation_t opa = convertTransToCublasOperation(transa);
+ cublasOperation_t opb = convertTransToCublasOperation(transb);
+
+ if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
+ {
+ int i_m = (int)m;
+ int i_n = (int)n;
+ int i_k = (int)k;
+ int i_lda = (int)lda;
+ int i_ldb = (int)ldb;
+ int i_ldc = (int)ldc;
+
+ THCublasCheck(cublasDgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
+ return;
+ }
+ THError("Cublas_Dgemm only supports m, n, k, lda, ldb, ldc"
+ "with the bound [val] <= %d", INT_MAX);
+}
+
+
+void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
+ float alpha, const float *a[], long lda, const float *b[], long ldb,
+ float beta, float *c[], long ldc, long batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
- THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc, batchCount"
+ THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}
@@ -235,22 +293,61 @@ void THCudaBlas_gemmBatched(THCState *state, char transa, char transb, long m, l
(int)batchCount));
}
+void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
+ double alpha, const double *a[], long lda, const double *b[], long ldb,
+ double beta, double *c[], long ldc, long batchCount)
+{
+ if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
+ {
+ THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
+ "with the bound [val] <= %d", INT_MAX);
+ }
+
+ adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
+ cublasOperation_t opa = convertTransToCublasOperation(transa);
+ cublasOperation_t opb = convertTransToCublasOperation(transb);
+
+ THCublasCheck(cublasDgemmBatched(THCState_getCurrentBlasHandle(state),
+ opa, opb, (int)m, (int)n, (int)k,
+ &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
+ (int)batchCount));
+}
+
/* Inverse */
-void THCudaBlas_getrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
+void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
{
- THError("Cublas_getrf only supports n, lda, batchSize"
+ THError("Cublas_Sgetrf only supports n, lda, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasSgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize));
}
-void THCudaBlas_getri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
+void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) {
+ if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
+ {
+ THError("Cublas_Dgetrf only supports n, lda, batchSize"
+ "with the bound [val] <= %d", INT_MAX);
+ }
+ THCublasCheck(cublasDgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize));
+}
+
+void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
{
- THError("Cublas_getrf only supports n, lda, ldc, batchSize"
+ THError("Cublas_Sgetri only supports n, lda, ldc, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasSgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize));
}
+
+void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) {
+
+ if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
+ {
+ THError("Cublas_Dgetri only supports n, lda, ldc, batchSize"
+ "with the bound [val] <= %d", INT_MAX);
+ }
+ THCublasCheck(cublasDgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize));
+}