
github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-02-16 22:59:12 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-02-16 22:59:12 +0300
commit     327cfc1cc3fbe3ab92927c800aa30aef1c41d517 (patch)
tree       9e67c0035ebb7d2c022982443efc1ad338ebc615 /src
parent     dd296e77f76143033fc1589c7dce6d12196bbfdd (diff)
pass through backend
Diffstat (limited to 'src')
-rw-r--r--  src/graph/expression_graph.cpp          |  6
-rw-r--r--  src/graph/node_operators_unary.h        | 24
-rw-r--r--  src/graph/parameters.h                  |  6
-rw-r--r--  src/kernels/sparse.cu                   |  6
-rw-r--r--  src/kernels/sparse.h                    | 33
-rw-r--r--  src/kernels/tensor_operators.cu         |  6
-rw-r--r--  src/optimizers/optimizers.cu            |  4
-rw-r--r--  src/tensors/cpu/dropout.cpp             |  2
-rw-r--r--  src/tensors/dispatch.h                  | 25
-rw-r--r--  src/tensors/gpu/dropout.cu              |  5
-rw-r--r--  src/tensors/tensor.cu                   | 16
-rw-r--r--  src/tensors/tensor.h                    | 12
-rw-r--r--  src/tensors/tensor_allocator.h          | 10
-rw-r--r--  src/training/graph_group_async.cu       | 14
-rw-r--r--  src/training/graph_group_async_drop.cu  | 21
-rw-r--r--  src/training/graph_group_async_drop.h   |  2
-rw-r--r--  src/training/graph_group_multinode.cu   | 15
-rw-r--r--  src/training/graph_group_multinode.h    |  2
-rw-r--r--  src/training/graph_group_sync.cu        |  4
-rw-r--r--  src/training/sparse_tensor.cu           | 24
-rw-r--r--  src/training/sparse_tensor.h            |  8
21 files changed, 116 insertions, 129 deletions
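
In short, this commit replaces the bare DeviceId that tensors, allocators, parameters and sparse tensors used to carry with a shared Ptr<Backend>, so that downstream code (e.g. Dropout) can recover the device and, on GPU, the cuRAND generator from the tensor itself instead of taking a separate backend argument. The following is a minimal, hypothetical sketch of that idea with simplified stand-in types, not the real marian headers:

    #include <memory>

    // Simplified stand-ins for the marian types touched by this commit
    // (the real classes live in src/tensors/backend.h, src/tensors/tensor.h
    // and src/common/definitions.h).
    enum class DeviceType { cpu, gpu };
    struct DeviceId { size_t no; DeviceType type; };

    struct Backend {
      DeviceId deviceId;
      // The real GPU backend also owns the cuRAND generator, which is why
      // passing the backend around replaces passing a bare DeviceId.
      DeviceId getDevice() const { return deviceId; }
    };

    class TensorBase {
      std::shared_ptr<Backend> backend_;   // was: DeviceId deviceId_;
    public:
      explicit TensorBase(std::shared_ptr<Backend> backend) : backend_(backend) {}
      std::shared_ptr<Backend> getBackend() { return backend_; }
      DeviceId getDevice() { return backend_->getDevice(); }   // kept for old callers
    };

    int main() {
      auto gpu0 = std::make_shared<Backend>(Backend{DeviceId{0, DeviceType::gpu}});
      TensorBase t(gpu0);
      // Code such as Dropout(t, p) can now derive everything it needs from the
      // tensor itself instead of taking a Ptr<Backend> as an extra argument.
      return t.getDevice().no == 0 ? 0 : 1;
    }
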
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index 183b5787..934e2b73 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -12,15 +12,15 @@ void ExpressionGraph::setDevice(DeviceId deviceId) {
if(!backend_) {
backend_ = BackendByDevice(deviceId, Config::seed);
params_ = New<Parameters>();
- params_->init(backend_->getDevice());
- tensors_ = New<TensorAllocator>(backend_->getDevice());
+ params_->init(backend_);
+ tensors_ = New<TensorAllocator>(backend_);
}
}
Expr ExpressionGraph::dropout(float prob, Shape shape) {
return Expression<ConstantNode>(shared_from_this(),
keywords::init = [prob, this](Tensor t) {
- Dropout(backend_, t, prob);
+ Dropout(t, prob);
},
keywords::shape = shape);
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f27fd2..0170fc73 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp {
const std::string type() { return "tanh"; }
};
-/**
- * Represents a <a
- * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
- * linear</a> node in an expression graph.
- *
- * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
- * its derivative:
- * \f[
- * f^\prime(x) =
- * \begin{cases}
- * 0 & \text{if } x \leq 0 \\
- * 1 & \text{if } x > 0
- * \end{cases}
- * \f]
- */
+
struct ReLUNodeOp : public UnaryNodeOp {
template <typename... Args>
ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -877,14 +863,14 @@ public:
Tensor& val() {
auto childVal = reshapee_->val();
val_.reset(
- new TensorBase(childVal->memory(), shape(), childVal->getDevice()));
+ new TensorBase(childVal->memory(), shape(), childVal->getBackend()));
return val_;
};
Tensor& grad() {
auto childGrad = reshapee_->grad();
adj_.reset(
- new TensorBase(childGrad->memory(), shape(), childGrad->getDevice()));
+ new TensorBase(childGrad->memory(), shape(), childGrad->getBackend()));
return adj_;
};
@@ -953,7 +939,7 @@ public:
size_t offset = step_ * shape().elements() * sizeof(float);
auto mem = New<MemoryPiece>(childVal->memory()->data() + offset,
childVal->memory()->size());
- val_.reset(new TensorBase(mem, shape(), childVal->getDevice()));
+ val_.reset(new TensorBase(mem, shape(), childVal->getBackend()));
return val_;
};
@@ -962,7 +948,7 @@ public:
size_t offset = step_ * shape().elements() * sizeof(float);
auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset,
childGrad->memory()->size());
- adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice()));
+ adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend()));
return adj_;
};
diff --git a/src/graph/parameters.h b/src/graph/parameters.h
index ed8b7690..3f282e4a 100644
--- a/src/graph/parameters.h
+++ b/src/graph/parameters.h
@@ -20,9 +20,9 @@ private:
Ptr<TensorAllocator> grads_;
public:
- void init(DeviceId deviceId) {
- vals_ = New<TensorAllocator>(deviceId);
- grads_ = New<TensorAllocator>(deviceId);
+ void init(Ptr<Backend> backend) {
+ vals_ = New<TensorAllocator>(backend);
+ grads_ = New<TensorAllocator>(backend);
}
auto begin() -> decltype(params_.begin()) { return params_.begin(); }
diff --git a/src/kernels/sparse.cu b/src/kernels/sparse.cu
index 1d104474..b5080c0c 100644
--- a/src/kernels/sparse.cu
+++ b/src/kernels/sparse.cu
@@ -12,7 +12,7 @@ void multiply(Ptr<CSR> C,
const Ptr<CSR> B,
bool transA,
bool transB) {
- cudaSetDevice(C->getDevice());
+ cudaSetDevice(backend_->getDevice().no);
int nnzTotal;
C->allocRowIndices(A->rows());
CUSPARSE_CHECK(cusparseXcsrgemmNnz(
@@ -91,7 +91,7 @@ void multiply(Ptr<CSR> C,
//}
void LfaForward(Tensor out, Tensor logits, Tensor att, Ptr<CSR> sparseLf) {
- cudaSetDevice(out->getDevice());
+ cudaSetDevice(backend_->getDevice().no);
int batch = att->shape()[0];
int srcWords = att->shape()[2];
@@ -150,7 +150,7 @@ __global__ void gCollapseAtt(float* out,
}
void CollapseAtt(Tensor out, Tensor in) {
- cudaSetDevice(out->getDevice());
+ cudaSetDevice(backend_->getDevice().no);
int nonzeros = out->shape().elements();
int batch = out->shape()[0];
int srcWords = out->shape()[2];
diff --git a/src/kernels/sparse.h b/src/kernels/sparse.h
index d70555f1..cffb398e 100644
--- a/src/kernels/sparse.h
+++ b/src/kernels/sparse.h
@@ -14,7 +14,7 @@ private:
int nnz_{0};
int rows_{0};
int cols_{0};
- DeviceId deviceId_;
+ Ptr<Backend> backend_;
cusparseHandle_t handle_{0};
cusparseMatDescr_t descr_{0};
@@ -24,9 +24,9 @@ private:
float* values_{0};
public:
- CSR(int rows, int cols, DeviceId deviceId)
- : rows_(rows), cols_(cols), deviceId_(deviceId) {
- cudaSetDevice(deviceId_.no);
+ CSR(int rows, int cols, Ptr<Backend> backend)
+ : rows_(rows), cols_(cols), backend_(backend) {
+ cudaSetDevice(backend_->getDevice().no);
CUSPARSE_CHECK(cusparseCreate(&handle_));
CUSPARSE_CHECK(cusparseCreateMatDescr(&descr_));
CUSPARSE_CHECK(cusparseSetMatType(descr_, CUSPARSE_MATRIX_TYPE_GENERAL));
@@ -38,9 +38,9 @@ public:
const std::vector<float>& values,
const std::vector<int>& rowIndices,
const std::vector<int>& colIndices,
- DeviceId deviceId)
- : nnz_(values.size()), rows_(rows), cols_(cols), deviceId_(deviceId) {
- cudaSetDevice(deviceId_.no);
+ Ptr<Backend> backend)
+ : nnz_(values.size()), rows_(rows), cols_(cols), backend_(backend) {
+ cudaSetDevice(backend_->getDevice().no);
CUSPARSE_CHECK(cusparseCreate(&handle_));
CUSPARSE_CHECK(cusparseCreateMatDescr(&descr_));
CUSPARSE_CHECK(cusparseSetMatType(descr_, CUSPARSE_MATRIX_TYPE_GENERAL));
@@ -73,8 +73,8 @@ public:
CUDA_CHECK(cudaFree(cooRowIndices));
}
- CSR(Tensor dense) : deviceId_(dense->getDevice()) {
- cudaSetDevice(deviceId_.no);
+ CSR(Tensor dense) : backend_(dense->getBackend()) {
+ cudaSetDevice(backend_->getDevice().no);
rows_ = dense->shape()[0] * dense->shape()[2] * dense->shape()[3];
cols_ = dense->shape()[1];
@@ -114,7 +114,7 @@ public:
}
~CSR() {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
if(values_)
CUDA_CHECK(cudaFree(values_));
if(rowIndices_)
@@ -129,7 +129,7 @@ public:
}
void toTensor(Tensor dense) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
ABORT_IF(dense->size() != rows_ * cols_, "Matrix sizes do not match");
cusparseScsc2dense(handle_,
@@ -154,10 +154,10 @@ public:
int* rowIndices() { return rowIndices_; }
int* colIndices() { return colIndices_; }
- DeviceId getDevice() { return deviceId_; }
+ DeviceId getDevice() { return backend_->getDevice(); }
void allocValues(int nnz = 0) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
if(nnz > 0)
nnz_ = nnz;
if(values_)
@@ -166,7 +166,7 @@ public:
}
void allocRowIndices(int rows = 0) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
if(rows > 0)
rows_ = rows;
if(rowIndices_)
@@ -175,7 +175,7 @@ public:
}
void allocColIndices(int nnz = 0) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
if(nnz > 0)
nnz_ = nnz;
if(colIndices_)
@@ -184,11 +184,12 @@ public:
}
std::string debug() {
+ cudaSetDevice(backend_->getDevice().no);
uint8_t* buffer;
CUDA_CHECK(cudaMalloc(&buffer, sizeof(float) * rows() * cols()));
auto mem = New<MemoryPiece>(buffer, sizeof(float) * rows() * cols());
- Tensor tensor(new TensorBase(mem, {rows(), cols()}, deviceId_));
+ Tensor tensor(new TensorBase(mem, {rows(), cols()}, backend_));
toTensor(tensor);
std::string temp = tensor->debug();
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 69b8afc4..87019d9e 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -893,7 +893,7 @@ void Select(Ptr<Allocator> allocator,
auto mp_indices = allocator->alloc<size_t>(indices.size());
CudaCopy(indices.data(), indices.data() + indices.size(), mp_indices->data());
-
+
int axisGPU = axis + gpu::Shape::size() - out->shape().size();
gSelect<<<blocks, threads>>>(out->data(),
out->shape(),
@@ -919,7 +919,7 @@ void Insert(Ptr<Allocator> allocator,
auto mp_indices = allocator->alloc<size_t>(indices.size());
CudaCopy(indices.data(), indices.data() + indices.size(), mp_indices->data());
-
+
int axisGPU = axis + gpu::Shape::size() - out->shape().size();
gInsert<<<blocks, threads>>>(out->data(),
out->shape(),
@@ -1295,7 +1295,7 @@ float L2Norm(Tensor in) {
uint8_t* data;
cudaMalloc(&data, blocks * sizeof(float));
Tensor out(new TensorBase(
- New<MemoryPiece>(data, blocks * sizeof(float)), {1, blocks}, in->getDevice()));
+ New<MemoryPiece>(data, blocks * sizeof(float)), {1, blocks}, in->getBackend()));
ReduceAll(_1 * _1, out, in);
float dataCpu = sqrtf(out->get(0));
diff --git a/src/optimizers/optimizers.cu b/src/optimizers/optimizers.cu
index 2874e1d5..e82800c9 100644
--- a/src/optimizers/optimizers.cu
+++ b/src/optimizers/optimizers.cu
@@ -13,7 +13,7 @@ void Sgd::updateImpl(Tensor params, Tensor grads) {
void Adagrad::updateImpl(Tensor params, Tensor grads) {
if(!alloc_)
- alloc_ = New<TensorAllocator>(params->getDevice());
+ alloc_ = New<TensorAllocator>(params->getBackend());
if(!gt_) {
int elements = params->size();
@@ -42,7 +42,7 @@ void Adagrad::resetStats() {
void Adam::updateImpl(Tensor params, Tensor grads) {
if(!alloc_)
- alloc_ = New<TensorAllocator>(params->getDevice());
+ alloc_ = New<TensorAllocator>(params->getBackend());
if(!mt_) {
int elements = params->size();
diff --git a/src/tensors/cpu/dropout.cpp b/src/tensors/cpu/dropout.cpp
index 4286042b..cc6cea41 100644
--- a/src/tensors/cpu/dropout.cpp
+++ b/src/tensors/cpu/dropout.cpp
@@ -5,7 +5,7 @@
namespace marian {
namespace cpu {
- void Dropout(Ptr<marian::Backend> backend, Tensor tensor, float p) {
+ void Dropout(Tensor tensor, float p) {
ABORT("Not implemented");
std::fill(tensor->data(), tensor->data() + tensor->size(), p);
}
diff --git a/src/tensors/dispatch.h b/src/tensors/dispatch.h
index 14b74f5a..e63a6af1 100644
--- a/src/tensors/dispatch.h
+++ b/src/tensors/dispatch.h
@@ -1,35 +1,34 @@
#pragma once
#include "common/definitions.h"
-#include "tensors/backend.h"
#include "tensors/tensor.h"
#define DISPATCH1(Function, Arg1) \
namespace gpu { \
- void Function(Ptr<marian::Backend>, Arg1); \
+ void Function(Arg1); \
} \
namespace cpu { \
- void Function(Ptr<marian::Backend>, Arg1); \
+ void Function(Arg1); \
} \
- void Function(Ptr<marian::Backend> backend, Arg1 arg1) { \
- if(backend->getDevice().type == DeviceType::gpu) \
- gpu::Function(backend, arg1); \
+ void Function(Arg1 arg1) { \
+ if(arg1->getBackend()->getDevice().type == DeviceType::gpu) \
+ gpu::Function(arg1); \
else \
- cpu::Function(backend, arg1); \
+ cpu::Function(arg1); \
}
#define DISPATCH2(Function, Arg1, Arg2) \
namespace gpu { \
- void Function(Ptr<marian::Backend>, Arg1, Arg2); \
+ void Function(Arg1, Arg2); \
} \
namespace cpu { \
- void Function(Ptr<marian::Backend>, Arg1, Arg2); \
+ void Function(Arg1, Arg2); \
} \
- static inline void Function(Ptr<marian::Backend> backend, Arg1 arg1, Arg2 arg2) { \
- if(backend->getDevice().type == DeviceType::gpu) \
- gpu::Function(backend, arg1, arg2); \
+ static inline void Function(Arg1 arg1, Arg2 arg2) { \
+ if(arg1->getBackend()->getDevice().type == DeviceType::gpu) \
+ gpu::Function(arg1, arg2); \
else \
- cpu::Function(backend, arg1, arg2); \
+ cpu::Function(arg1, arg2); \
}
namespace marian {
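
With this change the DISPATCH macros above no longer thread an explicit Ptr<marian::Backend> through every call; the generated wrapper reads the backend off its first tensor argument and routes to the gpu:: or cpu:: implementation accordingly. Below is a self-contained, simplified sketch of roughly what DISPATCH1(Dropout, Tensor) now generates; the stand-in types are hypothetical, not the real headers:

    #include <cstdio>
    #include <memory>

    // Minimal stand-ins for the types the generated wrapper relies on.
    enum class DeviceType { cpu, gpu };
    struct DeviceId { size_t no; DeviceType type; };
    struct Backend { DeviceId id; DeviceId getDevice() const { return id; } };

    struct TensorBase {
      std::shared_ptr<Backend> backend_;
      std::shared_ptr<Backend> getBackend() { return backend_; }
    };
    using Tensor = std::shared_ptr<TensorBase>;

    // Per-device implementations plus the thin dispatching wrapper; note that
    // none of them takes a backend parameter any more.
    namespace gpu { void Dropout(Tensor) { std::puts("gpu::Dropout"); } }
    namespace cpu { void Dropout(Tensor) { std::puts("cpu::Dropout"); } }

    void Dropout(Tensor arg1) {
      if(arg1->getBackend()->getDevice().type == DeviceType::gpu)
        gpu::Dropout(arg1);
      else
        cpu::Dropout(arg1);
    }

    int main() {
      auto backend = std::make_shared<Backend>(Backend{DeviceId{0, DeviceType::gpu}});
      Tensor t = std::make_shared<TensorBase>(TensorBase{backend});
      Dropout(t);   // routes to gpu::Dropout because of the tensor's backend
    }
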
diff --git a/src/tensors/gpu/dropout.cu b/src/tensors/gpu/dropout.cu
index 6dc49c51..4a4223a8 100644
--- a/src/tensors/gpu/dropout.cu
+++ b/src/tensors/gpu/dropout.cu
@@ -35,8 +35,9 @@ namespace marian {
}
}
- void Dropout(Ptr<marian::Backend> backend, Tensor tensor, float p) {
- curandGenerator_t gen = std::static_pointer_cast<gpu::Backend>(backend)->getCurandGenerator();
+ void Dropout(Tensor tensor, float p) {
+ auto gpuBackend = std::static_pointer_cast<gpu::Backend>(tensor->getBackend());
+ curandGenerator_t gen = gpuBackend->getCurandGenerator();
int n = tensor->size();
CURAND_CALL(curandGenerateUniform(gen, tensor->data(), n));
diff --git a/src/tensors/tensor.cu b/src/tensors/tensor.cu
index bc26fcec..96d979bf 100644
--- a/src/tensors/tensor.cu
+++ b/src/tensors/tensor.cu
@@ -16,7 +16,7 @@ __global__ void gFill(float *d_in, int size, float val) {
}
float TensorBase::get(size_t i) {
- cudaSetDevice(deviceId_.no);
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
float temp;
CUDA_CHECK(
cudaMemcpy(&temp, data() + i, sizeof(float), cudaMemcpyDeviceToHost));
@@ -25,14 +25,14 @@ float TensorBase::get(size_t i) {
}
void TensorBase::set(size_t i, float value) {
- cudaSetDevice(deviceId_.no);
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
CUDA_CHECK(
cudaMemcpy(data() + i, &value, sizeof(float), cudaMemcpyHostToDevice));
cudaStreamSynchronize(0);
}
void TensorBase::get(std::vector<float> &v) {
- CUDA_CHECK(cudaSetDevice(deviceId_.no));
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
v.resize(size());
CUDA_CHECK(cudaMemcpy(
v.data(), data(), size() * sizeof(float), cudaMemcpyDeviceToHost));
@@ -40,7 +40,7 @@ void TensorBase::get(std::vector<float> &v) {
}
void TensorBase::set(float value) {
- cudaSetDevice(deviceId_.no);
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
int threads = std::min(512, (int)size());
int blocks = (size() / threads) + (size() % threads != 0);
gFill<<<blocks, threads>>>(data(), size(), value);
@@ -48,7 +48,7 @@ void TensorBase::set(float value) {
}
void TensorBase::set(const std::vector<float> &v) {
- CUDA_CHECK(cudaSetDevice(deviceId_.no));
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
CUDA_CHECK(cudaMemcpy(
data(), v.data(), v.size() * sizeof(float), cudaMemcpyHostToDevice));
cudaStreamSynchronize(0);
@@ -56,13 +56,13 @@ void TensorBase::set(const std::vector<float> &v) {
void TensorBase::setSparse(const std::vector<size_t> &k,
const std::vector<float> &v) {
- cudaSetDevice(deviceId_.no);
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
SetSparse(data(), k, v);
cudaStreamSynchronize(0);
}
void TensorBase::copyFrom(Tensor in) {
- cudaSetDevice(deviceId_.no);
+ CUDA_CHECK(cudaSetDevice(backend_->getDevice().no));
CUDA_CHECK(cudaMemcpy(data(),
(float *)in->data(),
in->size() * sizeof(float),
@@ -74,7 +74,7 @@ std::string TensorBase::debug() {
std::stringstream strm;
assert(shape_.size());
strm << shape_;
- strm << " device=" << deviceId_;
+ strm << " device=" << backend_->getDevice();
strm << " ptr=" << (size_t)memory_->data();
strm << " bytes=" << memory_->size();
strm << std::endl;
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index 15d3a427..309cacd5 100644
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -9,6 +9,7 @@
#include "common/definitions.h"
#include "common/shape.h"
#include "tensors/memory_piece.h"
+#include "tensors/backend.h"
namespace marian {
@@ -16,11 +17,11 @@ class TensorBase : public std::enable_shared_from_this<TensorBase> {
private:
Ptr<MemoryPiece> memory_;
Shape shape_;
- DeviceId deviceId_;
+ Ptr<Backend> backend_;
public:
- TensorBase(Ptr<MemoryPiece> memory, Shape shape, DeviceId deviceId)
- : memory_(memory), shape_(shape), deviceId_(deviceId) {}
+ TensorBase(Ptr<MemoryPiece> memory, Shape shape, Ptr<Backend> backend)
+ : memory_(memory), shape_(shape), backend_(backend) {}
~TensorBase() {}
@@ -39,12 +40,13 @@ public:
return get(0);
}
- DeviceId getDevice() { return deviceId_; }
+ Ptr<Backend> getBackend() { return backend_; }
+ DeviceId getDevice() { return backend_->getDevice(); }
Tensor subtensor(int offset, int size) {
auto mem = New<MemoryPiece>(memory_->data() + sizeof(float) * offset,
sizeof(float) * size);
- return Tensor(new TensorBase(mem, {1, size}, deviceId_));
+ return New<TensorBase>(mem, Shape{1, size}, backend_);
}
float get(size_t i);
diff --git a/src/tensors/tensor_allocator.h b/src/tensors/tensor_allocator.h
index e1c54b22..18aae134 100644
--- a/src/tensors/tensor_allocator.h
+++ b/src/tensors/tensor_allocator.h
@@ -16,11 +16,13 @@ private:
const size_t GROW = CHUNK * MBYTE;
const size_t ALIGN = 256;
+ Ptr<Backend> backend_;
Ptr<Allocator> allocator_;
public:
- TensorAllocator(DeviceId deviceId)
- : allocator_(New<Allocator>(deviceId, 0, GROW, ALIGN)) {}
+ TensorAllocator(Ptr<Backend> backend)
+ : backend_(backend),
+ allocator_(New<Allocator>(backend_->getDevice(), 0, GROW, ALIGN)) {}
~TensorAllocator() { clear(); }
@@ -58,7 +60,7 @@ public:
if(!t || t->shape() != shape) {
int size = shape.elements();
auto mem = allocator_->alloc<float>(size);
- t = Tensor(new TensorBase(mem, shape, allocator_->getDevice()));
+ t = Tensor(new TensorBase(mem, shape, backend_));
}
}
@@ -67,7 +69,7 @@ public:
Tensor asTensor() {
auto mem = allocator_->memory();
int size = mem->size() / sizeof(float);
- return Tensor(new TensorBase(mem, {1, size}, allocator_->getDevice()));
+ return Tensor(new TensorBase(mem, {1, size}, backend_));
}
size_t size() { return allocator_->size() / sizeof(float); }
diff --git a/src/training/graph_group_async.cu b/src/training/graph_group_async.cu
index 90d575bd..18f3908a 100644
--- a/src/training/graph_group_async.cu
+++ b/src/training/graph_group_async.cu
@@ -93,12 +93,12 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) {
int pos = 0;
// parameter sharding
- for(auto device : devices_) {
+ for(auto graph : graphs_) {
int __size__ = min(shardSize_, totalSize);
totalSize -= __size__;
Tensor param;
- Ptr<TensorAllocator> allocator = New<TensorAllocator>(DeviceId{device, DeviceType::gpu});
+ Ptr<TensorAllocator> allocator = New<TensorAllocator>(graph->getBackend());
allocator->reserveExact(__size__ * sizeof(float));
allocator->allocate(param, {1, __size__});
paramsAlloc_.push_back(allocator);
@@ -112,11 +112,11 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) {
if(grads_.size() == 0) {
int totalSize = graphs_[0]->params()->vals()->size();
- for(auto device : devices_) {
+ for(auto graph : graphs_) {
int __size__ = min(shardSize_, totalSize);
totalSize -= __size__;
Tensor grad_;
- Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(DeviceId{device, DeviceType::gpu});
+ Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(graph->getBackend());
allocator_->reserveExact(__size__ * sizeof(float));
allocator_->allocate(grad_, {1, __size__});
@@ -129,11 +129,11 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) {
int totalSize = graphs_[0]->params()->vals()->size();
int i = 0;
- for(auto device : devices_) {
+ for(auto graph : graphs_) {
int __size__ = min(shardSize_, totalSize);
totalSize -= __size__;
Tensor paramAvg;
- Ptr<TensorAllocator> allocator = New<TensorAllocator>(DeviceId{device, DeviceType::gpu});
+ Ptr<TensorAllocator> allocator = New<TensorAllocator>(graph->getBackend());
allocator->reserveExact(__size__ * sizeof(float));
allocator->allocate(paramAvg, {1, __size__});
@@ -187,7 +187,7 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
Tensor gradients;
if(tau_ > 1) {
if(t == 0) {
- accAlloc = New<TensorAllocator>(graph->getDevice());
+ accAlloc = New<TensorAllocator>(graph->getBackend());
accAlloc->reserveExact(graph->params()->grads()->memory()->size());
accAlloc->allocate(accGradients, graph->params()->grads()->shape());
accGradients->set(0);
diff --git a/src/training/graph_group_async_drop.cu b/src/training/graph_group_async_drop.cu
index 84985009..33684bf9 100644
--- a/src/training/graph_group_async_drop.cu
+++ b/src/training/graph_group_async_drop.cu
@@ -8,9 +8,9 @@
namespace marian {
-Tensor AsyncGraphGroupDrop::newTensor(int size, DeviceId deviceId) {
+Tensor AsyncGraphGroupDrop::newTensor(int size, Ptr<Backend> backend) {
Tensor t;
- Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(deviceId);
+ Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(backend);
allocator_->reserveExact(size * sizeof(float));
allocator_->allocate(t, {1, size});
allocators.push_back(allocator_);
@@ -86,7 +86,7 @@ void AsyncGraphGroupDrop::pushGradients(Tensor newGrads,
// get the sparse gradient
pushDropper_[device_id]->dropGraph(
- newGrads, pushSparseGradient_[device_id],
+ newGrads, pushSparseGradient_[device_id],
droping_rate, dropping_momentum);
SparseTensor newSparseGrads = pushSparseGradient_[device_id];
@@ -146,13 +146,12 @@ void AsyncGraphGroupDrop::init(Ptr<data::Batch> batch) {
fetchStep_.push_back(0);
pushStep_.push_back(0);
- size_t device = devices_[i];
// temporary tensor to compute parameter delta before fetching
- paramsDelta_.push_back(newTensor(shardSize, {device, DeviceType::gpu}));
+ paramsDelta_.push_back(newTensor(shardSize, graphs_[i]->getBackend()));
// tensors to store local params history
for(int h_id = 0; h_id < devices_.size(); h_id++) {
- Tensor tmp = newTensor(params_[i]->size(), {device, DeviceType::gpu});
+ Tensor tmp = newTensor(params_[i]->size(), graphs_[i]->getBackend());
tmp->copyFrom(params_[i]);
paramsLocal_[h_id].push_back(tmp);
}
@@ -162,23 +161,23 @@ void AsyncGraphGroupDrop::init(Ptr<data::Batch> batch) {
// N-dropper for fetch
std::vector<GradientDrop> tmpDropper;
- for(int i = 0; i < devices_.size(); i++)
+ for(auto device : devices_)
tmpDropper.push_back(GradientDrop(new GradientDropBase()));
fetchDropper.push_back(tmpDropper);
// sparsetensor to store sparsified gradients per-device
pushSparseGradient_.push_back(
- SparseTensor(new SparseTensorBase(sparseCap, {device, DeviceType::gpu})));
+ SparseTensor(new SparseTensorBase(sparseCap, graphs_[i]->getBackend())));
pushShardedSparseGradient_.push_back(
- SparseTensor(new SparseTensorBase(sparseCap, {device, DeviceType::gpu})));
+ SparseTensor(new SparseTensorBase(sparseCap, graphs_[i]->getBackend())));
fetchSparseGradient_.push_back(SparseTensor(
- new SparseTensorBase(sparseCap / devices_.size(), {device, DeviceType::gpu})));
+ new SparseTensorBase(sparseCap / devices_.size(), graphs_[i]->getBackend())));
std::vector<SparseTensor> tmp;
for(int i = 0; i < devices_.size(); i++)
tmp.push_back(SparseTensor(
- new SparseTensorBase(sparseCap / devices_.size(), {device, DeviceType::gpu})));
+ new SparseTensorBase(sparseCap / devices_.size(), graphs_[i]->getBackend())));
fetchShardedSparseGradient_.push_back(tmp);
}
diff --git a/src/training/graph_group_async_drop.h b/src/training/graph_group_async_drop.h
index 33e74a40..f32d9444 100644
--- a/src/training/graph_group_async_drop.h
+++ b/src/training/graph_group_async_drop.h
@@ -31,7 +31,7 @@ class AsyncGraphGroupDrop : public AsyncGraphGroup {
std::vector<Ptr<TensorAllocator>> allocators;
- Tensor newTensor(int size, DeviceId deviceId);
+ Tensor newTensor(int size, Ptr<Backend> backend);
protected:
void init(Ptr<data::Batch> batch);
diff --git a/src/training/graph_group_multinode.cu b/src/training/graph_group_multinode.cu
index 78cd842f..34aa2b5d 100644
--- a/src/training/graph_group_multinode.cu
+++ b/src/training/graph_group_multinode.cu
@@ -19,9 +19,9 @@ void MultiNodeGraphGroup::setScheduler(Ptr<Scheduler> scheduler) {
/**
* Allocate new tensor on given GPU and store allocator.
*/
-Tensor MultiNodeGraphGroup::newTensor(int size, DeviceId deviceId) {
+Tensor MultiNodeGraphGroup::newTensor(int size, Ptr<Backend> backend) {
Tensor t;
- Ptr<TensorAllocator> allocator = New<TensorAllocator>(deviceId);
+ Ptr<TensorAllocator> allocator = New<TensorAllocator>(backend);
allocator->reserveExact(size * sizeof(float));
allocator->allocate(t, {1, size});
allocators_.push_back(allocator);
@@ -148,14 +148,12 @@ void MultiNodeGraphGroup::initClientCommOverlapVars() {
void MultiNodeGraphGroup::initClientCommOverlapGpuTensors() {
size_t modelSize = clientGraphs_[0]->params()->vals()->size();
for(int client = 0; client < devices_.size(); client++) {
- DeviceId deviceId{devices_[client], DeviceType::gpu};
-
// Communication overlap buffer (for grads + params)
- Tensor commOverlapBuffer = newTensor(modelSize, deviceId);
+ Tensor commOverlapBuffer = newTensor(modelSize, clientGraphs_[client]->getBackend());
commOverlapBuffer->copyFrom(clientGraphs_[0]->params()->vals());
clientCommOverlapBuffersGPU_.push_back(commOverlapBuffer);
// Gradients local sum buffer
- Tensor sumGrads = newTensor(modelSize, deviceId);
+ Tensor sumGrads = newTensor(modelSize, clientGraphs_[client]->getBackend());
sumGrads->set(0);
clientSummedGradsGPU.push_back(sumGrads);
// Local optimizer to apply summed gradients
@@ -207,12 +205,11 @@ void MultiNodeGraphGroup::calculateShardSizes() {
void MultiNodeGraphGroup::initShardGpuTensors() {
size_t offset = 0;
for(int shard = 0; shard < devices_.size(); shard++) {
- DeviceId deviceId{devices_[shard], DeviceType::gpu};
- Tensor gpuParams = newTensor(shardSizes_[shard], deviceId);
+ Tensor gpuParams = newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend());
gpuParams->copyFrom(clientGraphs_[0]->params()->vals()->subtensor(
offset, shardSizes_[shard]));
shardParams_.push_back(gpuParams);
- shardGrads_.push_back(newTensor(shardSizes_[shard], deviceId));
+ shardGrads_.push_back(newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend()));
}
}
diff --git a/src/training/graph_group_multinode.h b/src/training/graph_group_multinode.h
index 90243f7c..c6dc495c 100644
--- a/src/training/graph_group_multinode.h
+++ b/src/training/graph_group_multinode.h
@@ -217,7 +217,7 @@ protected:
/**
* Allocate new tensor on given GPU and store allocator.
*/
- Tensor newTensor(int size, DeviceId deviceId);
+ Tensor newTensor(int size, Ptr<Backend> backend);
/**
* Setup training environment and launch server thread and (if enabled) client
diff --git a/src/training/graph_group_sync.cu b/src/training/graph_group_sync.cu
index e713b99d..171c0652 100644
--- a/src/training/graph_group_sync.cu
+++ b/src/training/graph_group_sync.cu
@@ -60,7 +60,7 @@ void SyncGraphGroup::execute(Ptr<data::Batch> batch) {
for(auto graph : graphs_) {
int __size__ = min(shardSize_, totalSize);
- auto paramsAlloc = New<TensorAllocator>(graph->getDevice());
+ auto paramsAlloc = New<TensorAllocator>(graph->getBackend());
paramsAllocs_.push_back(paramsAlloc);
paramsAlloc->reserveExact(3 * __size__ * sizeof(float));
@@ -87,7 +87,7 @@ void SyncGraphGroup::execute(Ptr<data::Batch> batch) {
int __size__ = min(shardSize_, totalSize);
totalSize -= __size__;
Tensor paramAvg;
- auto allocator = New<TensorAllocator>(graph->getDevice());
+ auto allocator = New<TensorAllocator>(graph->getBackend());
allocator->reserveExact(__size__ * sizeof(float));
allocator->allocate(paramAvg, {1, __size__});
diff --git a/src/training/sparse_tensor.cu b/src/training/sparse_tensor.cu
index 5d32b16f..aafafa97 100644
--- a/src/training/sparse_tensor.cu
+++ b/src/training/sparse_tensor.cu
@@ -45,9 +45,9 @@ __global__ void gFindSubtensor(int* indices,
resultEnd[0] = idx;
}
-SparseTensorBase::SparseTensorBase(int capacity, DeviceId deviceId)
-: deviceId_(deviceId), capacity_(capacity) {
- cudaSetDevice(deviceId_.no);
+SparseTensorBase::SparseTensorBase(int capacity, Ptr<Backend> backend)
+: backend_(backend), capacity_(capacity) {
+ cudaSetDevice(backend_->getDevice().no);
CUDA_CHECK(cudaMalloc(&data_, sizeof(float) * capacity));
CUDA_CHECK(cudaMalloc(&indices_, sizeof(int) * capacity));
@@ -58,8 +58,8 @@ SparseTensorBase::SparseTensorBase(int capacity, DeviceId deviceId)
SparseTensorBase::SparseTensorBase(float* data,
int* indices,
int size,
- DeviceId deviceId)
-: deviceId_(deviceId) {
+ Ptr<Backend> backend)
+: backend_(backend) {
data_ = data;
indices_ = indices;
size_ = size;
@@ -93,7 +93,7 @@ void SparseTensorBase::copyFrom(float* data,
size_ = size;
if(size == 0)
return;
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
cudaMemcpy(data_, data, size * sizeof(float), cudaMemcpyDefault);
if(!data_only)
@@ -107,8 +107,8 @@ void SparseTensorBase::copyFrom(std::shared_ptr<SparseTensorBase> t,
copyFrom(t->data(), t->indices(), t->size(), data_only);
}
-DeviceId SparseTensorBase::getDevice() {
- return deviceId_;
+Ptr<Backend> SparseTensorBase::getBackend() {
+ return backend_;
}
void SparseTensorBase::setSize(int size) {
@@ -117,7 +117,7 @@ void SparseTensorBase::setSize(int size) {
// return the dense representation of this tensor
void SparseTensorBase::toDense(Tensor t, int offset) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
int threads = 512;
int blocks = 1 + size_ / threads;
t->set(0);
@@ -127,7 +127,7 @@ void SparseTensorBase::toDense(Tensor t, int offset) {
}
void SparseTensorBase::scatterAdd(Tensor t, int offset) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
cudaStreamSynchronize(0);
int threads = 512;
int blocks = 1 + size_ / threads;
@@ -139,7 +139,7 @@ void SparseTensorBase::scatterAdd(Tensor t, int offset) {
std::shared_ptr<SparseTensorBase> SparseTensorBase::subtensor(int pos,
int size,
int idx) {
- cudaSetDevice(deviceId_.no);
+ cudaSetDevice(backend_->getDevice().no);
cudaStreamSynchronize(0);
int* start = gstart_ + idx;
int* end = gend_ + idx;
@@ -165,6 +165,6 @@ std::shared_ptr<SparseTensorBase> SparseTensorBase::subtensor(int pos,
int subtensorSize = std::max(0, endOffset - startOffset + 1);
cudaStreamSynchronize(0);
return std::shared_ptr<SparseTensorBase>(new SparseTensorBase(
- data_ + startOffset, indices_ + startOffset, subtensorSize, deviceId_));
+ data_ + startOffset, indices_ + startOffset, subtensorSize, backend_));
}
}
diff --git a/src/training/sparse_tensor.h b/src/training/sparse_tensor.h
index 03fb53a9..9194748f 100644
--- a/src/training/sparse_tensor.h
+++ b/src/training/sparse_tensor.h
@@ -10,15 +10,15 @@ class SparseTensorBase : public std::enable_shared_from_this<SparseTensorBase> {
int* indices_;
int size_;
int capacity_;
- DeviceId deviceId_;
+ Ptr<Backend> backend_;
int* d_is_unsorted;
int* gstart_;
int* gend_;
public:
- SparseTensorBase(int capacity, DeviceId deviceId);
- SparseTensorBase(float* data, int* indices, int size, DeviceId deviceId);
+ SparseTensorBase(int capacity, Ptr<Backend> backend);
+ SparseTensorBase(float* data, int* indices, int size, Ptr<Backend> backend);
~SparseTensorBase() {}
@@ -43,7 +43,7 @@ public:
void scatterAdd(Tensor t, int offset = 0);
std::shared_ptr<SparseTensorBase> subtensor(int pos, int size, int idx);
- DeviceId getDevice();
+ Ptr<Backend> getBackend();
void toDense(Tensor t, int offset);
};