diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-08-04 10:27:15 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-04 10:27:15 +0300 |
commit | 845f71786bc5c21ee5d1a131eb1bbfc48a793371 (patch) | |
tree | 1d663c4471fe890a354595b563525396b08fd20a | |
parent | 78cc62e1beddc95b96902467f1e9fccf9db590a3 (diff) | |
parent | 3e19457912a5310c3f7d566b668a23f1b9f92c3c (diff) |
Merge pull request #273 from wenyong-h/opt-copyrows
avoid cudaMalloc in CopyRows
-rw-r--r-- | src/graph/node_operators_unary.h | 2 | ||||
-rw-r--r-- | src/tensors/cpu/tensor_operators.cpp | 3 | ||||
-rw-r--r-- | src/tensors/gpu/tensor_operators.cu | 16 | ||||
-rw-r--r-- | src/tensors/tensor_operators.h | 2 |
4 files changed, 11 insertions, 12 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index cbeded24..b51c8a75 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -657,7 +657,7 @@ struct RowsNodeOp : public UnaryNodeOp { NodeOps forwardOps() { // @TODO: solve this with a tensor! - return {NodeOp(CopyRows(val_, child(0)->val(), indices_))}; + return {NodeOp(CopyRows(val_, child(0)->val(), indices_, graph()->allocator()))}; } NodeOps backwardOps() { diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 4d5d40dc..b82600e1 100644 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -381,7 +381,8 @@ void LogSoftmaxGrad(Tensor grad_, Tensor adj_, Tensor val_) { void CopyRows(Tensor out_, const Tensor in_, - const std::vector<size_t>& indices) { + const std::vector<size_t>& indices, + Ptr<Allocator> allocator) { size_t cols = in_->shape()[1]; size_t rows = indices.size(); diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu index 4f11f9f8..6eb5a8dd 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -735,7 +735,7 @@ __global__ void gCopyRows(float* out, } } -void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices) { +void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices, Ptr<Allocator> allocator) { cudaSetDevice(out->getDevice().no); size_t cols = in->shape().back(); @@ -744,17 +744,15 @@ void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices) { int threads = std::min(MAX_THREADS, (int)cols); int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy); - size_t* d_indices; - CUDA_CHECK(cudaMalloc(&d_indices, rowsToCopy * sizeof(size_t))); - CUDA_CHECK(cudaMemcpy(d_indices, - indices.data(), - rowsToCopy * sizeof(size_t), - cudaMemcpyHostToDevice)); + auto mp_indices = allocator->alloc<size_t>(rowsToCopy); + CudaCopy(indices.data(), + indices.data() + indices.size(), + mp_indices->data<size_t>()); gCopyRows<<<blocks, threads>>>( - out->data(), in->data(), cols, d_indices, rowsToCopy); + out->data(), in->data(), cols, mp_indices->data<size_t>(), rowsToCopy); - CUDA_CHECK(cudaFree(d_indices)); + allocator->free(mp_indices); } __global__ void gPasteRows(float* out, diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index 87f4f27d..573c5c0e 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -119,7 +119,7 @@ static inline void Deconcatenate(std::vector<marian::Tensor>& outputs, DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor) DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor) - DISPATCH3(CopyRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&) + DISPATCH4(CopyRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&, Ptr<Allocator>) DISPATCH3(PasteRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&) DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const std::vector<size_t>&) DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const std::vector<size_t>&) |