Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 19:18:07 +0300
committer Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 19:18:07 +0300
commit eaa22f49b79d8ed8dfd299822758430218628ad2 (patch)
tree 04dbc872d2e71d26e323fac4defa9c7896181857
parent a6104286f060400ec17cd028ec14923f6f63195d (diff)
parallelize Copy()
-rw-r--r-- src/amun/half/mblas/matrix_functions.h 29
1 file changed, 17 insertions, 12 deletions
diff --git a/src/amun/half/mblas/matrix_functions.h b/src/amun/half/mblas/matrix_functions.h
index c3bb4aa5..16b20464 100644
--- a/src/amun/half/mblas/matrix_functions.h
+++ b/src/amun/half/mblas/matrix_functions.h
@@ -94,32 +94,37 @@ void copy(const T *in, size_t count, T *out, cudaMemcpyKind kind) {
template<typename T1, typename T2>
__global__ void gCopy(const VectorWrapper<T1> in, VectorWrapper<T2> out)
{
- for (uint i = 0; i < in.size(); ++i) {
- T2 val = in[i];
- out[i] = val;
+ int id = threadIdx.x + blockIdx.x * blockDim.x;
+
+ if (id < out.size()) {
+ T2 val = in[id];
+ out[id] = val;
}
}
template<typename T1, typename T2>
-void Copy(const T1 *in, size_t count, T2 *out, cudaMemcpyKind kind)
+void Copy(const T1 *in, uint size, T2 *out, cudaMemcpyKind kind)
{
- std::cerr << "Copy1=" << count << std::endl;
+ uint threads = std::min((uint)MAX_THREADS, size);
+ uint blocks = (size / threads) + ((size % threads == 0) ? 0 : 1);
+
+ std::cerr << "Copy1=" << size << std::endl;
if (kind == cudaMemcpyDeviceToHost) {
- const VectorWrapper<T1> inWrap(in, count);
+ const VectorWrapper<T1> inWrap(in, size);
- Vector<T2> d_out(count);
+ Vector<T2> d_out(size);
VectorWrapper<T2> outWrap(d_out);
- gCopy<<<1,1,0, CudaStreamHandler::GetStream()>>>(inWrap, outWrap);
- copy(d_out.data(), count, out, cudaMemcpyDeviceToHost);
+ gCopy<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>(inWrap, outWrap);
+ copy(d_out.data(), size, out, cudaMemcpyDeviceToHost);
}
else if (kind == cudaMemcpyHostToDevice) {
- Vector<T1> d_in(in, count);
+ Vector<T1> d_in(in, size);
const VectorWrapper<T1> inWrap(d_in);
- VectorWrapper<T2> outWrap(out, count);
+ VectorWrapper<T2> outWrap(out, size);
- gCopy<<<1,1,0, CudaStreamHandler::GetStream()>>>(inWrap, outWrap);
+ gCopy<<<blocks, threads ,0, CudaStreamHandler::GetStream()>>>(inWrap, outWrap);
}
}