Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-11-18 01:33:31 +0300
committerGitHub <noreply@github.com>2016-11-18 01:33:31 +0300
commitdd86d97777e3cbca33ffbb0eb8a333abe522e44e (patch)
treec6b1fd6bc5793c61e948df640aa7b8b29220de3f
parent0afffe1e6d3c115f5e830dec999349fe38bdb06b (diff)
parentcd8e20962ae596c03134c5abde475c6ad8760e55 (diff)
Merge pull request #605 from gchanan/halfAddrAddmv
Add half support for addmv and addr.
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu36
1 files changed, 33 insertions, 3 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index 18ca290..63c9989 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -43,7 +43,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
THC_API void
THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec)
{
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
THAssert(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec));
if( (mat->nDimension != 2) || (vec->nDimension != 1) )
THError("matrix and vector expected");
@@ -57,6 +57,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
if(t->size[0] != mat->size[0])
THError("size mismatch");
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
if(r_ != t)
{
THCTensor_(resizeAs)(state, r_, t);
@@ -110,6 +111,21 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THCTensor_(free)(state, cmat);
}
+#elif defined(THC_REAL_IS_HALF)
+ // Currently no Hgemv/SgemvEx in Cublas
+ THCTensor *vecAsMatrix = THCTensor_(newWithTensor)(state, vec);
+ THCTensor_(resize2d)(state, vecAsMatrix, vecAsMatrix->size[0], 1);
+
+ THCTensor *tAsMatrix = THCTensor_(newWithTensor)(state, t);
+ THCTensor_(resize2d)(state, tAsMatrix, tAsMatrix->size[0], 1);
+
+ THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix);
+
+ // r_ will have answer as matrix, need to return a vecotr
+ THCTensor_(resize1d)(state, r_, r_->size[0]);
+ THCTensor_(free)(state, vecAsMatrix);
+ THCTensor_(free)(state, tAsMatrix);
+#endif
#else
THError("unimplemented data type");
#endif
@@ -118,7 +134,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THC_API void
THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2)
{
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
THAssert(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2));
if ( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) {
THError("vector and vector expected");
@@ -132,12 +148,13 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a
THError("size mismatch");
}
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
if (r_ != t) {
THCTensor_(resizeAs)(state, r_, t);
THCTensor_(copy)(state, r_, t);
}
- if(beta != 1) {
+ if(THCNumerics<real>::ne(beta, ScalarConvert<int, real>::to(1))) {
THCTensor_(mul)(state, r_, r_, beta);
}
@@ -187,6 +204,19 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a
THCTensor_(freeCopyTo)(state, cr, r_);
}
+#elif defined(THC_REAL_IS_HALF)
+ // currently no Hger/SgerEx in Cublas.
+ THCTensor *vec2T = THCTensor_(newWithTensor)(state, vec2);
+ THCTensor_(resize2d)(state, vec2T, vec2T->size[0], 1);
+ THCTensor_(transpose)(state, vec2T, NULL, 0, 1);
+
+ THCTensor *vec1M = THCTensor_(newWithTensor)(state, vec1);
+ THCTensor_(resize2d)(state, vec1M, vec1M->size[0], 1);
+
+ THCTensor_(addmm)(state, r_, beta, t, alpha, vec1M, vec2T);
+ THCTensor_(free)(state, vec2T);
+ THCTensor_(free)(state, vec1M);
+#endif
#else
THError("unimplemented data type");
#endif