Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-08-06 08:10:12 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-08-06 08:10:12 +0300
commit1e0f0a47c9cdf5c916c08b8c35a45cfdbb1d1c9c (patch)
tree4eda1d4fe86ee3c1589168eae10153050ad6f98e
parentd96c03f7cb599f4b1d7ccdfe05e13085242503c9 (diff)
parent7ce3873a315ad939e48147d525ba4de4dcb7fb8a (diff)
Merge branch 'mergeWithPublic' of ssh://vs-ssh.visualstudio.com:22/DefaultCollection/Marian/_ssh/marian-dev into mergeWithPublic
-rw-r--r--src/graph/node_operators_unary.h2
-rw-r--r--src/tensors/cpu/tensor_operators.cpp3
-rw-r--r--src/tensors/gpu/tensor_operators.cu16
-rw-r--r--src/tensors/tensor_operators.h2
4 files changed, 11 insertions, 12 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index cbeded24..b51c8a75 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -657,7 +657,7 @@ struct RowsNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
+ return {NodeOp(CopyRows(val_, child(0)->val(), indices_, graph()->allocator()))};
}
NodeOps backwardOps() {
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 4d5d40dc..b82600e1 100644
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -381,7 +381,8 @@ void LogSoftmaxGrad(Tensor grad_, Tensor adj_, Tensor val_) {
void CopyRows(Tensor out_,
const Tensor in_,
- const std::vector<size_t>& indices) {
+ const std::vector<size_t>& indices,
+ Ptr<Allocator> allocator) {
size_t cols = in_->shape()[1];
size_t rows = indices.size();
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index 4f11f9f8..6eb5a8dd 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -735,7 +735,7 @@ __global__ void gCopyRows(float* out,
}
}
-void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices) {
+void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices, Ptr<Allocator> allocator) {
cudaSetDevice(out->getDevice().no);
size_t cols = in->shape().back();
@@ -744,17 +744,15 @@ void CopyRows(Tensor out, const Tensor in, const std::vector<size_t>& indices) {
int threads = std::min(MAX_THREADS, (int)cols);
int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy);
- size_t* d_indices;
- CUDA_CHECK(cudaMalloc(&d_indices, rowsToCopy * sizeof(size_t)));
- CUDA_CHECK(cudaMemcpy(d_indices,
- indices.data(),
- rowsToCopy * sizeof(size_t),
- cudaMemcpyHostToDevice));
+ auto mp_indices = allocator->alloc<size_t>(rowsToCopy);
+ CudaCopy(indices.data(),
+ indices.data() + indices.size(),
+ mp_indices->data<size_t>());
gCopyRows<<<blocks, threads>>>(
- out->data(), in->data(), cols, d_indices, rowsToCopy);
+ out->data(), in->data(), cols, mp_indices->data<size_t>(), rowsToCopy);
- CUDA_CHECK(cudaFree(d_indices));
+ allocator->free(mp_indices);
}
__global__ void gPasteRows(float* out,
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 308b1d67..78f61888 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -133,7 +133,7 @@ static inline void Deconcatenate(std::vector<marian::Tensor>& outputs,
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
- DISPATCH3(CopyRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&)
+ DISPATCH4(CopyRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&, Ptr<Allocator>)
DISPATCH3(PasteRows, marian::Tensor, const marian::Tensor, const std::vector<size_t>&)
DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const std::vector<size_t>&)
DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const std::vector<size_t>&)