diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-11-18 01:33:31 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-18 01:33:31 +0300 |
commit | dd86d97777e3cbca33ffbb0eb8a333abe522e44e (patch) | |
tree | c6b1fd6bc5793c61e948df640aa7b8b29220de3f | |
parent | 0afffe1e6d3c115f5e830dec999349fe38bdb06b (diff) | |
parent | cd8e20962ae596c03134c5abde475c6ad8760e55 (diff) |
Merge pull request #605 from gchanan/halfAddrAddmv
Add half support for addmv and addr.
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index 18ca290..63c9989 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -43,7 +43,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src) THC_API void THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec) { -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THAssert(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec)); if( (mat->nDimension != 2) || (vec->nDimension != 1) ) THError("matrix and vector expected"); @@ -57,6 +57,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real if(t->size[0] != mat->size[0]) THError("size mismatch"); +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) if(r_ != t) { THCTensor_(resizeAs)(state, r_, t); @@ -110,6 +111,21 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real THCTensor_(free)(state, cmat); } +#elif defined(THC_REAL_IS_HALF) + // Currently no Hgemv/SgemvEx in Cublas + THCTensor *vecAsMatrix = THCTensor_(newWithTensor)(state, vec); + THCTensor_(resize2d)(state, vecAsMatrix, vecAsMatrix->size[0], 1); + + THCTensor *tAsMatrix = THCTensor_(newWithTensor)(state, t); + THCTensor_(resize2d)(state, tAsMatrix, tAsMatrix->size[0], 1); + + THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix); + + // r_ will have answer as matrix, need to return a vecotr + THCTensor_(resize1d)(state, r_, r_->size[0]); + THCTensor_(free)(state, vecAsMatrix); + THCTensor_(free)(state, tAsMatrix); +#endif #else THError("unimplemented data type"); #endif @@ -118,7 +134,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real THC_API void THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2) { -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THAssert(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2)); if ( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) { THError("vector and vector expected"); @@ -132,12 +148,13 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a THError("size mismatch"); } +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) if (r_ != t) { THCTensor_(resizeAs)(state, r_, t); THCTensor_(copy)(state, r_, t); } - if(beta != 1) { + if(THCNumerics<real>::ne(beta, ScalarConvert<int, real>::to(1))) { THCTensor_(mul)(state, r_, r_, beta); } @@ -187,6 +204,19 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a THCTensor_(freeCopyTo)(state, cr, r_); } +#elif defined(THC_REAL_IS_HALF) + // currently no Hger/SgerEx in Cublas. + THCTensor *vec2T = THCTensor_(newWithTensor)(state, vec2); + THCTensor_(resize2d)(state, vec2T, vec2T->size[0], 1); + THCTensor_(transpose)(state, vec2T, NULL, 0, 1); + + THCTensor *vec1M = THCTensor_(newWithTensor)(state, vec1); + THCTensor_(resize2d)(state, vec1M, vec1M->size[0], 1); + + THCTensor_(addmm)(state, r_, beta, t, alpha, vec1M, vec2T); + THCTensor_(free)(state, vec2T); + THCTensor_(free)(state, vec1M); +#endif #else THError("unimplemented data type"); #endif |