github.com/torch/cutorch.git
author    Trevor Killeen <killeentm@gmail.com>  2016-11-14 22:35:29 +0300
committer Trevor Killeen <killeentm@gmail.com>  2016-11-16 00:30:49 +0300
commit    24c3448ca942508e886763b312c30efe3a50beca (patch)
tree      a6d22910e446751830d410b85b785b44ff1ae48b
parent    30885a31b4c99244a8c79e1c98a531b6261530be (diff)
[cutorch mag2gen] generic MAGMA memory allocator function
-rw-r--r--  lib/THC/THCTensorMathMagma.cu | 53
1 file changed, 23 insertions(+), 30 deletions(-)
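
The patch below folds the two type-specific pinned allocators (th_magma_smalloc_pinned and th_magma_imalloc_pinned) into a single template: MAGMA's untyped magma_malloc_pinned takes a byte count, so the element count is scaled by sizeof(T) and the raw pointer is cast back to T* for the caller. A minimal standalone sketch of that pattern follows; the magma.h include and std::runtime_error (standing in for THError) are assumptions made for illustration, not part of the patched file.

// Sketch: generic pinned-host allocator in the style of this patch.
#include <cstddef>
#include <stdexcept>
#include <magma.h>  // magma_init, magma_malloc_pinned, magma_free_pinned, MAGMA_SUCCESS

template <typename T>
static inline T* th_magma_malloc_pinned(size_t n)
{
  void* ptr;
  // magma_malloc_pinned allocates page-locked host memory, sized in bytes.
  if (MAGMA_SUCCESS != magma_malloc_pinned(&ptr, n * sizeof(T)))
    throw std::runtime_error("not enough pinned host memory");
  return reinterpret_cast<T*>(ptr);
}

int main()
{
  magma_init();

  // Former th_magma_imalloc_pinned(n) / th_magma_smalloc_pinned(n) calls
  // become explicit instantiations of the one template:
  int*   ipiv  = th_magma_malloc_pinned<int>(128);
  float* hwork = th_magma_malloc_pinned<float>(1024);

  magma_free_pinned(hwork);
  magma_free_pinned(ipiv);

  magma_finalize();
  return 0;
}
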
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index 325278e..a616302 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,20 +23,13 @@ void THCMagma_init(THCState *state)
}
#ifdef USE_MAGMA
-static inline float* th_magma_smalloc_pinned(size_t n)
+template <typename T>
+static inline T* th_magma_malloc_pinned(size_t n)
{
- float* ptr;
- if (MAGMA_SUCCESS != magma_smalloc_pinned(&ptr, n))
+ void* ptr;
+ if (MAGMA_SUCCESS != magma_malloc_pinned(&ptr, n * sizeof(T)))
THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", n/268435456);
- return ptr;
-}
-
-static inline int* th_magma_imalloc_pinned(size_t n)
-{
- int* ptr;
- if (MAGMA_SUCCESS != magma_imalloc_pinned(&ptr, n))
- THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", n/268435456);
- return ptr;
+ return reinterpret_cast<T*>(ptr);
}
static void THCudaTensor_copyArray1d(THCState *state, THCudaTensor *self, float *src, int k)
@@ -109,7 +102,7 @@ void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, TH
float *a_data = THCudaTensor_data(state, a);
float *b_data = THCudaTensor_data(state, b);
- int *ipiv = th_magma_imalloc_pinned(n);
+ int *ipiv = th_magma_malloc_pinned<int>(n);
int info;
magma_sgesv_gpu(n, nrhs, a_data, n, ipiv, b_data, n, &info);
@@ -148,7 +141,7 @@ void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, TH
int info;
magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
- float *hwork = th_magma_smalloc_pinned((size_t)wkopt);
+ float *hwork = th_magma_malloc_pinned<float>((size_t)wkopt);
magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
magma_free_pinned(hwork);
@@ -175,8 +168,8 @@ void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, TH
float *input_data = THCudaTensor_data(state, input);
// eigen values and workspace
- float *w = th_magma_smalloc_pinned(n);
- float *wA = th_magma_smalloc_pinned(lda);
+ float *w = th_magma_malloc_pinned<float>(n);
+ float *wA = th_magma_malloc_pinned<float>(lda);
// compute optimal size of work array
int info;
@@ -184,8 +177,8 @@ void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, TH
int liwork;
magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, &lwork, -1, &liwork, -1, &info);
- float *work = th_magma_smalloc_pinned((size_t)lwork);
- int *iwork = th_magma_imalloc_pinned(liwork);
+ float *work = th_magma_malloc_pinned<float>((size_t)lwork);
+ int *iwork = th_magma_malloc_pinned<int>(liwork);
// compute eigenvalues and, optionally, eigenvectors
magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, work, (int) lwork, iwork, liwork, &info);
@@ -220,17 +213,17 @@ void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, TH
magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec;
int n = a_->size[0];
- float *a_data = th_magma_smalloc_pinned(n * n);
+ float *a_data = th_magma_malloc_pinned<float>(n * n);
THCudaTensor_copyTensor2d(state, a_data, a_);
- float *wr = th_magma_smalloc_pinned(n);
- float *wi = th_magma_smalloc_pinned(n);
+ float *wr = th_magma_malloc_pinned<float>(n);
+ float *wi = th_magma_malloc_pinned<float>(n);
float *vr_data = NULL;
int ldvr = 1;
if (jobvr == MagmaVec)
{
- vr_data = th_magma_smalloc_pinned(n * n);
+ vr_data = th_magma_malloc_pinned<float>(n * n);
ldvr = n;
}
@@ -240,7 +233,7 @@ void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, TH
magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
int lwork = (int) wkopt;
- float *work_data = th_magma_smalloc_pinned(lwork);
+ float *work_data = th_magma_malloc_pinned<float>(lwork);
magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
@@ -296,19 +289,19 @@ void THCudaTensor_gesvd2(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_,
int k = m < n ? m : n;
int j = (jobu == MagmaAllVec) ? m : k;
- float *a_data = th_magma_smalloc_pinned(m * n);
+ float *a_data = th_magma_malloc_pinned<float>(m * n);
THCudaTensor_copyTensor2d(state, a_data, a);
- float *rs_data = th_magma_smalloc_pinned(k);
- float *ru_data = th_magma_smalloc_pinned(m * j);
- float *rv_data = th_magma_smalloc_pinned(n * n);
+ float *rs_data = th_magma_malloc_pinned<float>(k);
+ float *ru_data = th_magma_malloc_pinned<float>(m * j);
+ float *rv_data = th_magma_malloc_pinned<float>(n * n);
float wkopt;
int info;
magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, &info);
int lwork = (int) wkopt;
- float *work_data = th_magma_smalloc_pinned(lwork);
+ float *work_data = th_magma_malloc_pinned<float>(lwork);
magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, &info);
@@ -346,7 +339,7 @@ void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a)
THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a);
float *input_data = THCudaTensor_data(state, input);
- int *ipiv = th_magma_imalloc_pinned(n);
+ int *ipiv = th_magma_malloc_pinned<int>(n);
THCudaTensor *work = THCudaTensor_newWithSize1d(state, lwork);
float *work_data = THCudaTensor_data(state, work);
@@ -566,7 +559,7 @@ void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCu
#endif
float *a_data = THCudaTensor_data(state, a);
- float *tau_data = th_magma_smalloc_pinned(n*n);
+ float *tau_data = th_magma_malloc_pinned<float>(n*n);
THCudaTensor *work = THCudaTensor_newWithSize1d(state, (2*k + ((n+31)/32)*32)*nb);
float *work_data = THCudaTensor_data(state, work);