github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-12-07 02:20:44 +0300
committer  Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-12-07 02:20:44 +0300
commit     e8ea37cd5b85e3df817b9ced68bef9cc64b45d16 (patch)
tree       1ccd3d3b7112615a6fae93a878c4c4523a5add2c
parent     bbc673c50fbf2faa90bdc44003d15087632262bc (diff)
Merged PR 21648: Allow for dynamic gradient scaling to fade out after N updates
-rwxr-xr-x  src/tensors/gpu/prod.cpp       6
-rw-r--r--  src/training/graph_group.cpp  17
-rw-r--r--  src/training/graph_group.h     1
3 files changed, 20 insertions(+), 4 deletions(-)
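
The commit extends the dynamic gradient scaling option from two fields (a sigma factor and an optional `log` switch) to three, adding a fadeout horizon measured in updates. Below is a minimal sketch of how such a three-field value could be split into the members this patch touches; the struct and helper names are illustrative assumptions, not Marian's API.

#include <string>
#include <vector>

// Sketch only: splitting a "factor [log] [fadeout]" option value into the
// members this patch touches. In Marian the parsed strings arrive as the
// vgc vector in graph_group.cpp; everything else here is illustrative.
struct DynamicGradientScalingConfig {
  float factor = 2.f;    // rescale when the gradient norm deviates by this many sigmas
  bool useLogs = false;  // compare log gradient norms instead of raw norms
  size_t fadeout = 0;    // 0 = never fade out; otherwise fade out after N updates (new in this PR)
};

DynamicGradientScalingConfig parseDynamicGradientScaling(const std::vector<std::string>& vgc) {
  DynamicGradientScalingConfig cfg;
  if(vgc.size() > 0) cfg.factor  = std::stof(vgc[0]);
  if(vgc.size() > 1) cfg.useLogs = (vgc[1] == "log");
  if(vgc.size() > 2) cfg.fadeout = std::stoul(vgc[2]);
  return cfg;
}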
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index bf0d2395..c72af4db 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -562,7 +562,11 @@ void ProdBatchedLegacy(marian::Tensor C,
ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
#if COMPILE_FP16
} else if(C->type() == Type::float16) { // not a *.cu file
- ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+ // we use computeType=float here for fp16 training as this seems more stable and roughly as fast
+ ProdBatchedTypedLegacy<half, float>(C, allocator, A, B, transA, transB, beta, scalar);
+
+ // original for reference:
+ // ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
#endif
} else {
ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());
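
The change above keeps fp16 storage for A, B and C but switches the accumulation (compute) type to fp32. The following standalone sketch shows what that element-type/compute-type split means; it is a naive CPU loop for illustration, not Marian's cuBLAS-backed ProdBatchedTypedLegacy.

#include <cstddef>
#include <vector>

// Sketch only: the element-type vs. compute-type split expressed by the
// <half, float> template arguments in the patched call.
template <typename ElementType, typename ComputeType>
void gemmNaive(const std::vector<ElementType>& A,  // m x k, row-major
               const std::vector<ElementType>& B,  // k x n, row-major
               std::vector<ElementType>& C,        // m x n, row-major
               size_t m, size_t k, size_t n,
               ComputeType alpha, ComputeType beta) {
  for(size_t i = 0; i < m; ++i) {
    for(size_t j = 0; j < n; ++j) {
      ComputeType acc = 0;  // products are accumulated in ComputeType, not ElementType
      for(size_t p = 0; p < k; ++p)
        acc += ComputeType(A[i * k + p]) * ComputeType(B[p * n + j]);
      C[i * n + j] = ElementType(alpha * acc + beta * ComputeType(C[i * n + j]));
    }
  }
}

// With ElementType = half and ComputeType = float, as in the patched call, inputs
// and outputs stay in fp16 while sums run in fp32, which avoids much of the
// overflow and rounding error that pure fp16 accumulation can introduce.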
diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp
index 03e5acf4..59cd4b6d 100644
--- a/src/training/graph_group.cpp
+++ b/src/training/graph_group.cpp
@@ -31,11 +31,16 @@ GraphGroup::GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
if(vgc.size() > 0) dynamicGradientScalingFactor_ = std::stof(vgc[0]);
if(vgc.size() > 1) dynamicGradientScalingUseLogs_ = vgc[1] == "log";
+ if(vgc.size() > 2) dynamicGradientScalingFadeout_ = std::stoul(vgc[2]);
LOG_ONCE(info,
"Re-scaling gradient to have average gradient norm if (log={}) gradient norm diverges from average by {} sigmas",
dynamicGradientScalingUseLogs_,
dynamicGradientScalingFactor_);
+ if(dynamicGradientScalingFadeout_ > 0)
+ LOG_ONCE(info,
+ "Dynamic gradient re-scaling will fade out linearly after {} updates",
+ dynamicGradientScalingFadeout_);
}
if(options_->get<bool>("check-gradient-nan")) {
@@ -229,11 +234,17 @@ float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords)
auto deltaTransform = gNormTransform - gNormAvgTransform; // compute the difference between the current transformer gradient norm and the running average.
auto gNormStdTransform = std::sqrt(gNormVarTransform); // compute STD for the running average of (log) gradient norms.
+ float fadeoutMultiplier = 1.f;
+ if(dynamicGradientScalingFadeout_ > 0ul) // fade out linearly after that many updates @TODO: allow units other than updates
+ fadeoutMultiplier = (float)std::max(dynamicGradientScalingFadeout_, scheduler_->numberOfBatches()) / (float)dynamicGradientScalingFadeout_;
+
+ float dynamicGradientScalingFactorWithFadeout = dynamicGradientScalingFactor_ * fadeoutMultiplier; // if fadeoutMultiplier increases dynamic gradient scaling becomes less and less likely to happen over time.
// delta of (log) gradient norm vs (log) gradient norm average is larger than N standard deviations
// hence rescale gradient using the average.
- if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactor_ * gNormStdTransform) {
- LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f}",
- dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactor_, gNormStdTransform);
+ if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactorWithFadeout * gNormStdTransform) {
+ if(isMainProcess())
+ LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f} - scaling gradient by {:.4f}",
+ dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactorWithFadeout, gNormStdTransform, gNormAvg / gNorm);
normalizationFactor *= gNorm / gNormAvg; // since we later do gradient / normalizationFactor this divides by norm and multiplies by the average, rescaling to the average.
}
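
The fadeout works by inflating the sigma threshold once training passes the configured number of updates: the multiplier is max(fadeout, numberOfBatches) / fadeout, so it stays at 1 until the horizon is reached and then grows linearly, making rescaling progressively rarer. A minimal standalone sketch of that arithmetic follows; the function name and the example numbers are illustrative, not taken from the patch.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Effective sigma threshold implied by the fadeout logic above: before
// `fadeout` updates the multiplier is 1, afterwards it grows linearly with
// the update count, so dynamic rescaling fires less and less often.
float effectiveThreshold(float factor, size_t fadeout, size_t numBatches) {
  float multiplier = 1.f;
  if(fadeout > 0)
    multiplier = (float)std::max(fadeout, numBatches) / (float)fadeout;
  return factor * multiplier;
}

int main() {
  // Illustrative numbers: factor of 2 sigmas, fadeout after 10000 updates.
  for(size_t updates : {5000, 10000, 20000, 40000})
    std::printf("updates=%zu -> rescale when delta > %.1f sigmas\n",
                updates, effectiveThreshold(2.f, 10000, updates));
  // prints thresholds of 2.0, 2.0, 4.0 and 8.0 sigmas respectively
  return 0;
}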
diff --git a/src/training/graph_group.h b/src/training/graph_group.h
index b7f2f7ef..aa68922a 100644
--- a/src/training/graph_group.h
+++ b/src/training/graph_group.h
@@ -74,6 +74,7 @@ protected:
bool dynamicGradientScaling_{false};
float dynamicGradientScalingFactor_{2.f};
bool dynamicGradientScalingUseLogs_{false};
+ size_t dynamicGradientScalingFadeout_{0ul};
// determines the number of input streams (i.e. input files or fields in the TSV input) that need
// to be included in the batch, i.e. without alignments and weights