github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2018-12-06 08:01:55 +0300
committer Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2018-12-06 08:01:55 +0300
commit    efc3bfa4dc23dd87b8ebd9ed9ead1664eb28108b (patch)
tree      3e0a2deb12a1b49fe56a6f6acbe5c53a76639153
parent    902dd122a991e2eab7a7fb7bb605a973e04e8ec3 (diff)
parent    1714781e65d8836e7423abbbdcff22a4191fc405 (diff)
Merge branch 'master' of https://github.com/marian-nmt/marian-dev
-rw-r--r--  CHANGELOG.md                    1
-rw-r--r--  VERSION                         2
-rwxr-xr-x  src/tensors/gpu/algorithm.cu   43
-rw-r--r--  src/tensors/gpu/algorithm.h     3
-rwxr-xr-x  src/tensors/tensor.h          116
-rwxr-xr-x  src/training/communicator.h    25
6 files changed, 140 insertions, 50 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2c2e48d..4bc326f2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Fixed
- Errors due to warnings
+- Fixed an issue where model saving failed with single-GPU training and the --sync-sgd option.
### Changed
- Set nearly all warnings as errors for Marian's own targets. Disable warnings for 3rd party.
diff --git a/VERSION b/VERSION
index a97fc441..3b34d229 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v1.7.1
+v1.7.2
diff --git a/src/tensors/gpu/algorithm.cu b/src/tensors/gpu/algorithm.cu
index 9aad629d..bdf66bac 100755
--- a/src/tensors/gpu/algorithm.cu
+++ b/src/tensors/gpu/algorithm.cu
@@ -49,7 +49,7 @@ void fill(Ptr<Backend> backend, T* begin, T* end, T value) {
  if (size == 0)
    return;
  CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
-  int threadsPerBlock = std::min(512, size);
+  int threadsPerBlock = std::min(MAX_THREADS, size);
  int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
  gFill<<<blocks, threadsPerBlock>>>(begin, size, value);
  CUDA_CHECK(cudaStreamSynchronize(0));
@@ -76,5 +76,46 @@ void setSparse(Ptr<Backend> backend,
  // gpu::SetSparse(data, keys, values);
  CUDA_CHECK(cudaStreamSynchronize(0));
}
+
+template <typename T>
+__global__ void gSwap(T* d_v1, T* d_v2, int size) {
+  auto threadsPerBlock = blockDim.x;
+  int index = threadIdx.x + threadsPerBlock * blockIdx.x;
+  if(index < size) {
+    T temp = d_v1[index];
+    d_v1[index] = d_v2[index];
+    d_v2[index] = temp;
+  }
+}
+
+template <typename T>
+void swap_ranges(Ptr<Backend> backend, T* begin, T* end, T* dest) {
+  int size = end - begin;
+  if (size == 0)
+    return;
+
+  CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+  int threadsPerBlock = std::min(MAX_THREADS, size);
+  int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
+  gSwap<<<blocks, threadsPerBlock>>>(begin, dest, size);
+  CUDA_CHECK(cudaStreamSynchronize(0));
+}
+
+// clang-format off
+template void swap_ranges<int8_t>(Ptr<Backend>, int8_t*, int8_t*, int8_t*);
+template void swap_ranges<int16_t>(Ptr<Backend>, int16_t*, int16_t*, int16_t*);
+template void swap_ranges<int32_t>(Ptr<Backend>, int32_t*, int32_t*, int32_t*);
+template void swap_ranges<int64_t>(Ptr<Backend>, int64_t*, int64_t*, int64_t*);
+
+template void swap_ranges<uint8_t>(Ptr<Backend>, uint8_t*, uint8_t*, uint8_t*);
+template void swap_ranges<uint16_t>(Ptr<Backend>, uint16_t*, uint16_t*, uint16_t*);
+template void swap_ranges<uint32_t>(Ptr<Backend>, uint32_t*, uint32_t*, uint32_t*);
+template void swap_ranges<uint64_t>(Ptr<Backend>, uint64_t*, uint64_t*, uint64_t*);
+
+template void swap_ranges<char>(Ptr<Backend>, char*, char*, char*);
+template void swap_ranges<float>(Ptr<Backend>, float*, float*, float*);
+template void swap_ranges<double>(Ptr<Backend>, double*, double*, double*);
+// clang-format on
+
} // namespace gpu
} // namespace marian
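
For orientation: the new gpu::swap_ranges mirrors std::swap_ranges but runs on the device, exchanging the contents of [begin, end) with the range starting at dest. A minimal usage sketch (the wrapper name swapBuffers and the variables a, b, n are hypothetical; the Ptr<Backend> parameter and the gpu::swap_ranges signature are taken from the diff above):

  // Sketch: exchange two same-length float buffers that live on one GPU.
  // 'backend' must be bound to the device that owns both buffers.
  void swapBuffers(marian::Ptr<marian::Backend> backend, float* a, float* b, int n) {
    // After the call, a[i] holds the old b[i] and vice versa for all i in [0, n).
    marian::gpu::swap_ranges(backend, a, a + n, b);
  }
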
diff --git a/src/tensors/gpu/algorithm.h b/src/tensors/gpu/algorithm.h
index 9b4480e9..84f8b41e 100644
--- a/src/tensors/gpu/algorithm.h
+++ b/src/tensors/gpu/algorithm.h
@@ -10,6 +10,9 @@ void copy(Ptr<marian::Backend> backend, const T* begin, const T* end, T* dest);
template <typename T>
void fill(Ptr<marian::Backend> backend, T* begin, T* end, T value);
+template <typename T>
+void swap_ranges(Ptr<marian::Backend> backend, T* begin, T* end, T* dest);
+
void setSparse(Ptr<marian::Backend> backend,
               const std::vector<size_t>&,
               const std::vector<float>&,
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index c40b94e0..9721670b 100755
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -56,15 +56,6 @@ public:
  virtual size_t size() { return shape_.elements(); }
-  virtual float scalar() {
-    ABORT_IF(!matchType<float>(type_),
-             "Requested type ({}) and underlying type ({}) do not match",
-             request<float>(),
-             type_);
-    ABORT_IF(size() != 1, "Tensor is not a scalar");
-    return get(0);
-  }
-
  template <typename T>
  T scalar() {
    ABORT_IF(!matchType<T>(type_),
@@ -76,6 +67,10 @@ public:
    return get<T>(0);
  }
+  virtual float scalar() {
+    return scalar<float>();
+  }
+
  Ptr<Backend> getBackend() { return backend_; }
  DeviceId getDeviceId() { return backend_->getDeviceId(); }
@@ -85,24 +80,6 @@ public:
    return New<TensorBase>(mem, Shape{1, (int)size}, backend_);
  }
-  float get(size_t i) {
-    ABORT_IF(!matchType<float>(type_),
-             "Requested type ({}) and underlying type ({}) do not match",
-             request<float>(),
-             type_);
-
-    float temp = 0; // (initialize to keep compiler happy)
-    if(backend_->getDeviceId().type == DeviceType::cpu) {
-      std::copy(data() + i, data() + i + 1, &temp);
-    }
-#ifdef CUDA_FOUND
-    else {
-      gpu::copy(backend_, data() + i, data() + i + 1, &temp);
-    }
-#endif
-    return temp;
-  }
-
  template <typename T>
  T get(size_t i) {
    ABORT_IF(!matchType<T>(type_),
@@ -122,6 +99,10 @@ public:
    return temp;
  }
+  float get(size_t i) {
+    return get<float>(i);
+  }
+
  template <typename T>
  void set(size_t i, T value) {
    ABORT_IF(!matchType<T>(type_),
@@ -228,24 +209,95 @@ public:
#endif
  }
+  template <typename T>
  void copyFrom(Tensor in) {
-    // @TODO: solve this later
-    ABORT_IF(!matchType<float>(type_),
+    ABORT_IF(in->shape() != shape_, "Can only copy tensors with equal shapes ({} != {})", in->shape(), shape_);
+    ABORT_IF(in->type() != type_, "Can only copy tensors with equal types ({} != {})", in->type(), type_);
+    ABORT_IF(!matchType<T>(type_),
             "Requested type ({}) and underlying type ({}) do not match",
-             request<float>(),
+             request<T>(),
             type_);
    if(in->getBackend()->getDeviceId().type == DeviceType::cpu
       && backend_->getDeviceId().type == DeviceType::cpu) {
-      std::copy(in->data(), in->data() + in->size(), data());
+      std::copy(in->data<T>(), in->data<T>() + in->size(), data<T>());
    }
#ifdef CUDA_FOUND
    else {
-      gpu::copy(backend_, in->data(), in->data() + in->size(), data());
+      gpu::copy(backend_, in->data<T>(), in->data<T>() + in->size(), data<T>());
    }
#endif
  }
+  void copyFrom(Tensor in) {
+    switch(type_) {
+      case Type::int8:    copyFrom<int8_t>(in);   break;
+      case Type::int16:   copyFrom<int16_t>(in);  break;
+      case Type::int32:   copyFrom<int32_t>(in);  break;
+      case Type::int64:   copyFrom<int64_t>(in);  break;
+
+      case Type::uint8:   copyFrom<uint8_t>(in);  break;
+      case Type::uint16:  copyFrom<uint16_t>(in); break;
+      case Type::uint32:  copyFrom<uint32_t>(in); break;
+      case Type::uint64:  copyFrom<uint64_t>(in); break;
+
+      case Type::float32: copyFrom<float>(in);    break;
+      case Type::float64: copyFrom<double>(in);   break;
+
+      default: ABORT("Unknown type {}", type_);
+    }
+  }
+
+  // Swaps the contents of the current tensor with the argument tensor
+  template <typename T>
+  void swap(Tensor swapee) {
+    ABORT_IF(swapee->shape() != shape_, "Can only swap tensors with equal shapes ({} != {})", swapee->shape(), shape_);
+    ABORT_IF(swapee->type() != type_, "Can only swap tensors with equal types ({} != {})", swapee->type(), type_);
+    ABORT_IF(!matchType<T>(type_),
+             "Requested type ({}) and underlying type ({}) do not match",
+             request<T>(),
+             type_);
+
+    // we live on CPUs; just use stdlib
+    if(swapee->getBackend()->getDeviceId().type == DeviceType::cpu
+       && backend_->getDeviceId().type == DeviceType::cpu) {
+      std::swap_ranges(swapee->data<T>(), swapee->data<T>() + swapee->size(), data<T>());
+    }
+#ifdef CUDA_FOUND
+    else {
+      if(backend_->getDeviceId() == swapee->getBackend()->getDeviceId()) {
+        // we live on the same GPU; do an element-wise swap
+        gpu::swap_ranges(backend_, swapee->data<T>(), swapee->data<T>() + swapee->size(), data<T>());
+      } else {
+        // we live on two different GPUs or devices; go through CPU RAM
+        std::vector<T> temp;
+        get(temp);
+        copyFrom(swapee);
+        swapee->set(temp);
+      }
+    }
+#endif
+  }
+
+  void swap(Tensor swapee) {
+    switch(type_) {
+      case Type::int8:    swap<int8_t>(swapee);   break;
+      case Type::int16:   swap<int16_t>(swapee);  break;
+      case Type::int32:   swap<int32_t>(swapee);  break;
+      case Type::int64:   swap<int64_t>(swapee);  break;
+
+      case Type::uint8:   swap<uint8_t>(swapee);  break;
+      case Type::uint16:  swap<uint16_t>(swapee); break;
+      case Type::uint32:  swap<uint32_t>(swapee); break;
+      case Type::uint64:  swap<uint64_t>(swapee); break;
+
+      case Type::float32: swap<float>(swapee);    break;
+      case Type::float64: swap<double>(swapee);   break;
+
+      default: ABORT("Unknown type {}", type_);
+    }
+  }
+
  template <typename T>
  std::string debug() {
    ABORT_IF(!matchType<T>(type_),
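
Taken together, the tensor.h changes turn the old float-only scalar(), get() and copyFrom() into thin wrappers around templated versions, and add a type-dispatched swap(). A rough usage sketch (the tensors a and b are hypothetical; per the new ABORT_IF checks they must have identical shape and element type):

  // Sketch: untyped calls dispatch on the runtime Type enum.
  marian::Tensor a, b;   // assume both allocated with equal shape and type
  a->copyFrom(b);        // copies b's contents into a
  a->swap(b);            // CPU: std::swap_ranges; same GPU: gpu::swap_ranges;
                         // different devices: staged through a CPU-side std::vector
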
diff --git a/src/training/communicator.h b/src/training/communicator.h
index 2d42e109..eef136f2 100755
--- a/src/training/communicator.h
+++ b/src/training/communicator.h
@@ -7,9 +7,9 @@
#include "optimizers/optimizers.h"
#if MPI_FOUND
#ifdef __GNUC__
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wsuggest-override"
-#endif
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wsuggest-override"
+#endif
#include "mpi.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
@@ -203,25 +203,18 @@ public:
  void swapParams(const std::vector<Tensor>& paramShards) const override {
    // Update all graphs with parameter shard
-    ABORT_IF(graphs_.size() < 2, "Swap requires at least two graphs");
-
+
    auto gather = [this, paramShards](size_t idx, size_t begin, size_t end) {
-      ABORT_IF(end-begin != paramShards[idx]->size(), "inconsistent shard size (swapParams, [{}], {} vs {})??", idx, end-begin, paramShards[idx]->size());
+      ABORT_IF(end - begin != paramShards[idx]->size(), "inconsistent shard size (swapParams, [{}], {} vs {})??", idx, end-begin, paramShards[idx]->size());
      // Copy parameter shard to each graph, apart from last graph
      for(int i = 0; i < (int)graphs_.size() - 1; ++i) {
-        auto subParam
-            = graphs_[i]->params()->vals()->subtensor(begin, paramShards[idx]->size());
+        auto subParam = graphs_[i]->params()->vals()->subtensor(begin, paramShards[idx]->size());
        subParam->copyFrom(paramShards[idx]);
      }
-      // Back-up shard from last graph
-      auto subParamLast =
-          graphs_.back()->params()->vals()->subtensor(begin, paramShards[idx]->size());
-      paramShards[idx]->copyFrom(subParamLast);
-
-      auto subParamFirst
-          = graphs_[0]->params()->vals()->subtensor(begin, paramShards[idx]->size());
-      subParamLast->copyFrom(subParamFirst);
+      // Swap shard with the corresponding shard from the last graph
+      auto subParamLast = graphs_.back()->params()->vals()->subtensor(begin, paramShards[idx]->size());
+      paramShards[idx]->swap(subParamLast);
    };
    // Execute for each shard
    foreach(gather);
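
The net effect of the swapParams() rewrite: each parameter shard is copied into every graph except the last, and then exchanged with the last graph's sub-tensor in a single swap, instead of being backed up and restored through two extra copies. Because the swap is well-defined even with only one graph, the ABORT_IF(graphs_.size() < 2, ...) guard could be dropped, which is presumably what fixes the single-GPU --sync-sgd saving failure noted in the CHANGELOG. A condensed view of the per-shard logic (comment-level sketch, names as in the diff):

  // For each shard idx covering [begin, end) of the flattened parameters:
  //   1. copy paramShards[idx] into graphs_[0 .. n-2] at [begin, end)
  //   2. swap paramShards[idx] with graphs_.back()'s sub-tensor at [begin, end)
  // With a single graph, step 1 is a no-op and step 2 still exchanges the
  // shard with the graph's parameters, so no special case is needed.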