diff options
Diffstat (limited to 'lib/TH/generic/THBlas.c')
-rw-r--r-- | lib/TH/generic/THBlas.c | 105 |
1 files changed, 46 insertions, 59 deletions
diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c index 195e655..371df4d 100644 --- a/lib/TH/generic/THBlas.c +++ b/lib/TH/generic/THBlas.c @@ -9,37 +9,24 @@ # define ffloat float #endif -// define MKL_LP64 to get 32bit ints on 64bit platforms -#ifndef MKL_LP64 - // 64bit ints - #ifdef WIN32 - #define BLAS_INT __int64 - #else - #define BLAS_INT long - #endif -#else - // 32bit ints - #define BLAS_INT int -#endif +TH_EXTERNC void dswap_(int *n, double *x, int *incx, double *y, int *incy); +TH_EXTERNC void sswap_(int *n, float *x, int *incx, float *y, int *incy); +TH_EXTERNC void dscal_(int *n, double *a, double *x, int *incx); +TH_EXTERNC void sscal_(int *n, float *a, float *x, int *incx); +TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy); +TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy); +TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy); +TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy); +TH_EXTERNC double ddot_(int *n, double *x, int *incx, double *y, int *incy); +TH_EXTERNC ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy); +TH_EXTERNC void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy); +TH_EXTERNC void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy); +TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda); +TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda); +TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc); +TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc); -TH_EXTERNC void dswap_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); -TH_EXTERNC void sswap_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); -TH_EXTERNC void dscal_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx); -TH_EXTERNC void sscal_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx); -TH_EXTERNC void dcopy_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); -TH_EXTERNC void scopy_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); -TH_EXTERNC void daxpy_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); -TH_EXTERNC void saxpy_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); -TH_EXTERNC double ddot_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); -TH_EXTERNC ffloat sdot_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); -TH_EXTERNC void dgemv_(char *trans, BLAS_INT *m, BLAS_INT *n, double *alpha, double *a, BLAS_INT *lda, double *x, BLAS_INT *incx, double *beta, double *y, BLAS_INT *incy); -TH_EXTERNC void sgemv_(char *trans, BLAS_INT *m, BLAS_INT *n, float *alpha, float *a, BLAS_INT *lda, float *x, BLAS_INT *incx, float *beta, float *y, BLAS_INT *incy); -TH_EXTERNC void dger_(BLAS_INT *m, BLAS_INT *n, double *alpha, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy, double *a, BLAS_INT *lda); -TH_EXTERNC void sger_(BLAS_INT *m, BLAS_INT *n, float *alpha, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy, float *a, BLAS_INT *lda); -TH_EXTERNC void dgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, double *alpha, double *a, BLAS_INT *lda, double *b, BLAS_INT *ldb, double *beta, double *c, BLAS_INT *ldc); -TH_EXTERNC void sgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, float *alpha, float *a, BLAS_INT *lda, float *b, BLAS_INT *ldb, float *beta, float *c, BLAS_INT *ldc); - void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) { @@ -52,9 +39,9 @@ void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) dswap_(&i_n, x, &i_incx, y, &i_incy); @@ -83,8 +70,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) ) { - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_incx = (BLAS_INT)incx; + int i_n = (int)n; + int i_incx = (int)incx; #if defined(TH_REAL_IS_DOUBLE) dscal_(&i_n, &a, x, &i_incx); @@ -112,9 +99,9 @@ void THBlas_(copy)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) dcopy_(&i_n, x, &i_incx, y, &i_incy); @@ -142,9 +129,9 @@ void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); @@ -172,9 +159,9 @@ real THBlas_(dot)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) return (real) ddot_(&i_n, x, &i_incx, y, &i_incy); @@ -203,11 +190,11 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re (incx > 0) && (incx <= INT_MAX) && (incy > 0) && (incy <= INT_MAX) ) { - BLAS_INT i_m = (BLAS_INT)m; - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_lda = (BLAS_INT)lda; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_m = (int)m; + int i_n = (int)n; + int i_lda = (int)lda; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); @@ -258,11 +245,11 @@ void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - BLAS_INT i_m = (BLAS_INT)m; - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_lda = (BLAS_INT)lda; - BLAS_INT i_incx = (BLAS_INT)incx; - BLAS_INT i_incy = (BLAS_INT)incy; + int i_m = (int)m; + int i_n = (int)n; + int i_lda = (int)lda; + int i_incx = (int)incx; + int i_incy = (int)incy; #if defined(TH_REAL_IS_DOUBLE) dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); @@ -317,12 +304,12 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) { - BLAS_INT i_m = (BLAS_INT)m; - BLAS_INT i_n = (BLAS_INT)n; - BLAS_INT i_k = (BLAS_INT)k; - BLAS_INT i_lda = (BLAS_INT)lda; - BLAS_INT i_ldb = (BLAS_INT)ldb; - BLAS_INT i_ldc = (BLAS_INT)ldc; + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; #if defined(TH_REAL_IS_DOUBLE) dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); |