/*
 * THCBlas.h -- lib/THC, github.com/torch/cutorch.git
 * blob: 45f58eba5ea22fb8c9e04f09285ab8f8180817cb
 */
#ifndef THC_BLAS_INC
#define THC_BLAS_INC

#include "THCGeneral.h"
#include "THCHalf.h"

/*
 * Level 1
 *
 * Sdot works on float vectors and returns a float; Ddot works on double
 * vectors and returns a double. Both take the THCState, a count n, and for
 * each of x and y a pointer plus an increment (incx / incy).
 *
 * NOTE(review): the names follow the BLAS xDOT routines, so these presumably
 * compute a strided dot product of n elements -- the implementation is not
 * part of this header; confirm there (including whether x and y must be
 * device pointers).
 */
THC_API float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy);
THC_API double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy);

/*
 * Level 2
 *
 * S-prefixed functions take float scalars and buffers, D-prefixed ones take
 * double. All return void and take the THCState first.
 *
 * gemv: a trans flag, dimensions m and n, scalar alpha, matrix a with leading
 *       dimension lda, vector x with increment incx, scalar beta, and vector
 *       y with increment incy.
 * ger:  dimensions m and n, scalar alpha, vectors x and y with their
 *       increments, and matrix a with leading dimension lda.
 *
 * NOTE(review): parameter order mirrors the BLAS xGEMV / xGER routines, so
 * presumably y is updated in place by gemv and a is updated in place by ger;
 * the accepted values of trans are not visible here -- confirm in the
 * implementation.
 */
THC_API void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy);
THC_API void THCudaBlas_Dgemv(THCState *state, char trans, long m, long n, double alpha, double *a, long lda, double *x, long incx, double beta, double *y, long incy);
THC_API void THCudaBlas_Sger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda);
THC_API void THCudaBlas_Dger(THCState *state, long m, long n, double alpha, double *x, long incx, double *y, long incy, double *a, long lda);

/*
 * Level 3
 *
 * Sgemm takes float scalars and matrices, Dgemm takes double. Both return
 * void and take: the THCState, per-operand flags transa and transb,
 * dimensions m, n and k, scalar alpha, matrices a and b with leading
 * dimensions lda and ldb, scalar beta, and matrix c with leading dimension
 * ldc.
 *
 * NOTE(review): the signature mirrors BLAS xGEMM, so c is presumably the
 * output operand; storage order and the accepted transa/transb values are
 * not visible in this header -- confirm in the implementation.
 */
THC_API void THCudaBlas_Sgemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc);
THC_API void THCudaBlas_Dgemm(THCState *state, char transa, char transb, long m, long n, long k, double alpha, double *a, long lda, double *b, long ldb, double beta, double *c, long ldc);

/*
 * Half-precision variant of the gemm declarations: same parameter list as
 * Sgemm/Dgemm, with `half` used for alpha, beta and the three matrices.
 * Only declared when CUDA_HALF_TENSOR is defined.
 * NOTE(review): `half` and CUDA_HALF_TENSOR are presumably provided via
 * THCHalf.h -- confirm.
 */
#ifdef CUDA_HALF_TENSOR
THC_API void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, long k, half alpha, half *a, long lda, half *b, long ldb, half beta, half *c, long ldc);
#endif

/*
 * Batched gemm declarations (float and double).
 *
 * Compared with Sgemm/Dgemm, the matrix operands are arrays of pointers
 * rather than single pointers: a and b are arrays of pointers to const data,
 * c is an array of pointers to mutable data. A trailing batchCount is added;
 * the dimensions, leading dimensions and scalars are single values, not
 * per-entry arrays.
 *
 * NOTE(review): presumably batchCount is the length of the a/b/c arrays and
 * one gemm is performed per entry; whether the pointer arrays themselves
 * must live in device memory is not visible here -- confirm in the
 * implementation.
 */
THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                     float alpha, const float *a[], long lda, const float *b[], long ldb,
                                     float beta, float *c[], long ldc, long batchCount);
THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                     double alpha, const double *a[], long lda, const double *b[], long ldb,
                                     double beta, double *c[], long ldc, long batchCount);

/*
 * Inverse
 *
 * Unlike the declarations above, these use `int` (not `long`) for n, the
 * leading dimensions and batchSize.
 *
 * getrf: takes an array of matrix pointers a (mutable), leading dimension
 *        lda, an int pivot buffer, an int info buffer, and batchSize.
 * getri: takes an array of pointers to const matrices a, lda, the pivot
 *        buffer, an array of output matrix pointers c with leading dimension
 *        ldc, the info buffer, and batchSize.
 *
 * NOTE(review): the names match LAPACK getrf (LU factorisation) and getri
 * (inverse from LU factors), so presumably getrf's a/pivot output is what
 * getri expects as input; the required sizes of pivot and info are not
 * visible in this header -- confirm in the implementation.
 */
THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);
THC_API void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize);
THC_API void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize);

#endif