diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-01 18:25:50 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-01 18:25:50 +0300 |
commit | 668aa6331cb9a3dc3ebe630788c2968bf38ffe90 (patch) | |
tree | 9fbe538af34d11688d96f95e0a0140fe8a6f2df2 | |
parent | 5507b1756cc4bb2879065da0255131ae74069471 (diff) |
custom Transpose. cublasHgeam does not exist
-rw-r--r-- | src/amun/half/mblas/matrix_functions.cu | 32 |
1 files changed, 26 insertions, 6 deletions
```diff
diff --git a/src/amun/half/mblas/matrix_functions.cu b/src/amun/half/mblas/matrix_functions.cu
index e0595e05..ee196bbd 100644
--- a/src/amun/half/mblas/matrix_functions.cu
+++ b/src/amun/half/mblas/matrix_functions.cu
@@ -172,18 +172,38 @@ void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const mbl
   */
 }
 
-Matrix& Transpose(Matrix& Out, const Matrix& In) {
+/////////////////////////////////////////////////////////////////////////////
+__global__ void gTranspose(MatrixWrapper<half> out, const MatrixWrapper<half> in)
+{
+  int id = threadIdx.x + blockIdx.x * blockDim.x;
+  //printf("id = %d in = %lu %lu %lu %lu = %lu %lu \n", id, in.dim(0), in.dim(1), in.dim(2), in.dim(3), in.size(), sizeof(in));
+
+  if (id < in.size()) {
+    uint indices[SHAPE_SIZE];
+    in.id2Indices(id, indices);
+
+    out(indices[1], indices[0], 0, 0) = in(indices[0], indices[1], 0, 0);
+  }
+}
+
+Matrix& Transpose(Matrix& Out, const Matrix& In)
+{
+  assert(In.dim(2) == 1);
+  assert(In.dim(3) == 1);
   size_t m = In.dim(0);
   size_t n = In.dim(1);
 
   Out.NewSize(n, m);
+  //cerr << "In=" << In.Debug(0) << endl;
+  //cerr << "Out=" << Out.Debug(0) << endl;
 
-  //HH
-  //half alpha = 1.0;
-  //half beta = 0.0;
+  MatrixWrapper<half> outWrap(Out);
+  const MatrixWrapper<half> inWrap(In);
+
+  int nThreads = MAX_THREADS;
+  int nBlocks = (In.size() / nThreads) + ((In.size() % nThreads == 0) ? 0 : 1);
 
-  //cublasHgeam(CublasHandler::GetHandle(), CUBLAS_OP_T, CUBLAS_OP_T, m, n, &alpha, In.data(), n,
-  //            &beta, In.data(), n, Out.data(), m);
+  gTranspose<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>(outWrap, inWrap);
 
   return Out;
 }
```