Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 18:25:50 +0300
committer: Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 18:25:50 +0300
commit: 668aa6331cb9a3dc3ebe630788c2968bf38ffe90 (patch)
tree: 9fbe538af34d11688d96f95e0a0140fe8a6f2df2
parent: 5507b1756cc4bb2879065da0255131ae74069471 (diff)
custom Transpose. cublasHgeam does not exist
-rw-r--r-- src/amun/half/mblas/matrix_functions.cu 32
1 files changed, 26 insertions, 6 deletions
diff --git a/src/amun/half/mblas/matrix_functions.cu b/src/amun/half/mblas/matrix_functions.cu
index e0595e05..ee196bbd 100644
--- a/src/amun/half/mblas/matrix_functions.cu
+++ b/src/amun/half/mblas/matrix_functions.cu
@@ -172,18 +172,38 @@ void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const mbl
*/
}
-Matrix& Transpose(Matrix& Out, const Matrix& In) {
+/////////////////////////////////////////////////////////////////////////////
+__global__ void gTranspose(MatrixWrapper<half> out, const MatrixWrapper<half> in)
+{
+ int id = threadIdx.x + blockIdx.x * blockDim.x;
+ //printf("id = %d in = %lu %lu %lu %lu = %lu %lu \n", id, in.dim(0), in.dim(1), in.dim(2), in.dim(3), in.size(), sizeof(in));
+
+ if (id < in.size()) {
+ uint indices[SHAPE_SIZE];
+ in.id2Indices(id, indices);
+
+ out(indices[1], indices[0], 0, 0) = in(indices[0], indices[1], 0, 0);
+ }
+}
+
+Matrix& Transpose(Matrix& Out, const Matrix& In)
+{
+ assert(In.dim(2) == 1);
+ assert(In.dim(3) == 1);
size_t m = In.dim(0);
size_t n = In.dim(1);
Out.NewSize(n, m);
+ //cerr << "In=" << In.Debug(0) << endl;
+ //cerr << "Out=" << Out.Debug(0) << endl;
- //HH
- //half alpha = 1.0;
- //half beta = 0.0;
+ MatrixWrapper<half> outWrap(Out);
+ const MatrixWrapper<half> inWrap(In);
+
+ int nThreads = MAX_THREADS;
+ int nBlocks = (In.size() / nThreads) + ((In.size() % nThreads == 0) ? 0 : 1);
- //cublasHgeam(CublasHandler::GetHandle(), CUBLAS_OP_T, CUBLAS_OP_T, m, n, &alpha, In.data(), n,
- // &beta, In.data(), n, Out.data(), m);
+ gTranspose<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>(outWrap, inWrap);
return Out;
}