github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-10-26 23:25:39 +0300
committer  Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-10-26 23:25:39 +0300
commit     1404201926b5b4e27993776d52dfac809e8556f4
tree       10d4cda76a78a3a3f607b543fce6602367ab6487
parent     7f06f3c5d2035dac0cb4349bf29fbfa3e6bb5448
Merged PR 21151: Cleaning up fp16 behavior
This PR improves the clipping and pruning of NaNs and Infs during fp16 training, ultimately avoiding the underflow problems we had been facing.
-rw-r--r--  src/common/aliases.cpp                  |   4
-rw-r--r--  src/common/config_parser.cpp            |   6
-rw-r--r--  src/common/definitions.h                |  10
-rw-r--r--  src/models/transformer.h                |  15
-rwxr-xr-x  src/tensors/cpu/tensor_operators.cpp    |   4
-rwxr-xr-x  src/tensors/gpu/element.cu              |  12
-rw-r--r--  src/tensors/gpu/tensor_operators.cu     | 147
-rw-r--r--  src/tensors/tensor_operators.h          |  19
-rw-r--r--  src/training/graph_group.cpp            | 118
-rw-r--r--  src/training/graph_group.h              |  17
-rw-r--r--  src/training/graph_group_async.cpp      |   6
-rw-r--r--  src/training/graph_group_singleton.cpp  |   8
-rw-r--r--  src/training/graph_group_sync.cpp       |   8
13 files changed, 233 insertions(+), 141 deletions(-)
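
For orientation before the per-file diffs: the mechanism being tuned here is dynamic cost (loss) scaling for fp16 training. The loss is multiplied by a scaling factor before backward(), the optimizer divides the gradient by the same factor again, and the factor is raised after a stretch of clean updates and lowered when NaN/Inf appears. Below is a minimal host-side C++ sketch of that schedule with the new defaults (factor 256, frequency 1000, multiplier 2, minimum 256); the struct and member names are illustrative, not Marian's API.

#include <cstddef>

// Minimal sketch of the dynamic cost-scaling schedule described in this PR.
// Defaults mirror the new implicit value "256.f 1000 2.f 256.f".
struct CostScaler {
  float factor         = 256.f;  // loss is multiplied by this before backward()
  float minimum        = 256.f;  // the factor never drops below this
  float multiplier     = 2.f;    // step for increases and decreases
  std::size_t freq     = 1000;   // clean updates between increases
  std::size_t clean    = 0;      // consecutive updates without NaN/Inf

  bool atMinimum() const { return factor <= minimum; }

  // A NaN/Inf gradient was seen (and not pruned away): back off and skip the update.
  void onBadGradient() {
    if(factor > minimum)
      factor /= multiplier;
    clean = 0;  // counters are reset after a factor change
  }

  // A clean gradient was applied: after `freq` of them, try to reclaim dynamic range.
  void onGoodGradient() {
    if(++clean % freq == 0) {
      factor *= multiplier;
      clean = 0;
    }
  }
};

At the minimum factor the PR stops skipping updates; instead it prunes NaNs and clips Infs in the gradient itself, which is what the new SanitizeGradient kernel further down implements.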
diff --git a/src/common/aliases.cpp b/src/common/aliases.cpp
index 0be26a8c..99574fe1 100644
--- a/src/common/aliases.cpp
+++ b/src/common/aliases.cpp
@@ -29,8 +29,8 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
cli.alias("fp16", "true", [&](YAML::Node& config) {
if(mode_ == cli::mode::training) {
config["precision"] = std::vector<std::string>({"float16", "float32"}); // inference type, optimization type, save type
- // scaling factor (power of 2), frequency, multiplier at increase, tolerance, range, minium factor
- config["cost-scaling"] = std::vector<std::string>({"0", "1000", "2", "0.05", "10", "1e-5"});
+ // scaling factor, frequency, multiplier at increase, minimum scaling factor
+ config["cost-scaling"] = std::vector<std::string>({"256.f", "1000", "2.f", "256.f"});
} else {
config["precision"] = std::vector<std::string>({"float16"}); // for inference we do not need the other types
}
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index b3e8950b..51764cdc 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -522,15 +522,15 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
// mixed precision training
cli.add<bool>("--fp16",
"Shortcut for mixed precision training with float16 and cost-scaling, "
- "corresponds to: --precision float16 float32 --cost-scaling 0 1000 2 0.05 10 1e-5f");
+ "corresponds to: --precision float16 float32 --cost-scaling 256.f 1000 2.f 256.f");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision training for forward/backward pass and optimizaton. "
"Defines types for: forward/backward pass, optimization.",
{"float32", "float32"});
cli.add<std::vector<std::string>>("--cost-scaling",
"Dynamic cost scaling for mixed precision training: "
- "power of 2, scaling window, scaling factor, tolerance, range, minimum factor")
- ->implicit_val("0.f 1000 2.f 0.05f 10 1e-5f");
+ "scaling factor, frequency, multiplier, minimum factor")
+ ->implicit_val("256.f 1000 2.f 256.f");
cli.add<size_t>("--gradient-norm-average-window",
"Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
"After this many updates about 90% of the mass of the exponential average comes from these updates",
diff --git a/src/common/definitions.h b/src/common/definitions.h
index d2cf8aa4..d8a3ad46 100644
--- a/src/common/definitions.h
+++ b/src/common/definitions.h
@@ -106,24 +106,24 @@ using Weak = std::weak_ptr<T>;
/** @brief Creates shared_ptr of any type, passes all arguments to any available
* constructor */
template <class T, typename... Args>
-Ptr<T> New(Args&&... args) {
- return Ptr<T>(new T(std::forward<Args>(args)...));
+inline Ptr<T> New(Args&&... args) {
+ return std::make_shared<T>(std::forward<Args>(args)...);
}
template <class T>
-Ptr<T> New(Ptr<T> p) {
+inline Ptr<T> New(Ptr<T> p) {
return Ptr<T>(p);
}
/** @brief Creates InstrusivePtr of any type, passes all arguments to any available
* constructor */
template <class T, typename... Args>
-IPtr<T> INew(Args&&... args) {
+inline IPtr<T> INew(Args&&... args) {
return IPtr<T>(new T(std::forward<Args>(args)...));
}
template <class T>
-IPtr<T> INew(Ptr<T> p) {
+inline IPtr<T> INew(Ptr<T> p) {
return IPtr<T>(p);
}
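
Two small things happen in definitions.h: New<T>(...) now goes through std::make_shared, which constructs the object and the shared_ptr control block in a single allocation, and the templates gain an explicit inline, which for function templates is essentially a stylistic hint. A self-contained usage sketch; the Thing type is made up for illustration:

#include <memory>
#include <utility>

template <class T>
using Ptr = std::shared_ptr<T>;

// Same shape as the updated factory in definitions.h: one allocation for the
// object plus the shared_ptr control block.
template <class T, typename... Args>
inline Ptr<T> New(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

struct Thing {
  int x;
  explicit Thing(int x) : x(x) {}
};

int main() {
  Ptr<Thing> t = New<Thing>(42);  // equivalent to std::make_shared<Thing>(42)
  return t->x == 42 ? 0 : 1;
}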
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 2393ad73..b2c0f6be 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -147,8 +147,7 @@ public:
int dimDepth = dimModel / dimHeads;
- auto output
- = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
+ auto output = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
return transpose(output, {0, 2, 1, 3}); // [dimBatch*dimBeam, dimHeads, dimSteps, dimDepth]
}
@@ -361,9 +360,9 @@ public:
Expr LayerAttention(std::string prefix,
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
- const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
- const Expr& values, // ...?
- const Expr& mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
+ Expr keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+ Expr values, // ...?
+ Expr mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
int dimHeads,
bool cache = false,
bool saveAttentionWeights = false) {
@@ -373,6 +372,12 @@ public:
auto opsPre = opt<std::string>("transformer-preprocess");
auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb);
+ // fixes missing norm for keys and values in self-attention with pre-norm
+ if(input == keys)
+ keys = output;
+ if(input == values)
+ values = output;
+
// multi-head self-attention over previous input
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
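
The transformer.h change affects pre-norm self-attention: preProcess applies the pre-norm operations (e.g. layer normalization) to the query input, but keys and values used to keep pointing at the raw layer input. In self-attention they are literally the same expression as the input, so the pointer-identity check redirects them to the normalized tensor too. A toy, self-contained C++ sketch of that data flow; Node, preProcess and multiHead here are stand-ins, not Marian's types:

#include <cstdio>
#include <memory>
#include <string>

// Toy stand-ins just to make the control flow concrete; not Marian's types.
struct Node { std::string tag; };
using Expr = std::shared_ptr<Node>;

Expr make(const std::string& tag) { return std::make_shared<Node>(Node{tag}); }
Expr preProcess(Expr x) { return make("norm(" + x->tag + ")"); }
Expr multiHead(Expr q, Expr k, Expr v) {
  return make("attn(q=" + q->tag + ", k=" + k->tag + ", v=" + v->tag + ")");
}

// Pre-norm attention after the fix: if keys/values alias the raw input
// (self-attention), they follow the query onto the normalized tensor.
Expr layerAttention(Expr input, Expr keys, Expr values) {
  Expr output = preProcess(input);
  if(keys == input)   keys   = output;
  if(values == input) values = output;
  return multiHead(output, keys, values);
}

int main() {
  Expr x = make("x");
  std::printf("%s\n", layerAttention(x, x, x)->tag.c_str());  // all three normalized
  return 0;
}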
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 1afb8f64..f3964f91 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -24,6 +24,10 @@ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, b
ABORT("Not implemented");
}
+bool SanitizeGradient(marian::Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool /*pruneNaN*/, bool /*clipInf*/) {
+ ABORT("Not implemented");
+}
+
template <bool add, typename To, typename From>
void CopyCastTo(To* out, const From* in, int length) {
for(int i = 0; i < length; ++i)
diff --git a/src/tensors/gpu/element.cu b/src/tensors/gpu/element.cu
index 6790efd4..e9cbe081 100755
--- a/src/tensors/gpu/element.cu
+++ b/src/tensors/gpu/element.cu
@@ -29,7 +29,9 @@ __global__ void gElement(
indices[i] = tensors[i].shape().bindex(dims);
}
- tensors[0].data()[index] = functional::apply(functor, tensors, indices);
+ // This performs the internal application of the functor in float32 regardless of the input type.
+ // It seems there are no speed penalties but improved precision.
+ tensors[0].data()[index] = (T)functional::applyWithCast<float>(functor, tensors, indices);
}
}
}
@@ -65,13 +67,7 @@ void Element(Functor functor, Tensor out, Tensors... tensors) {
ElementTyped<float>(functor, out, tensors...);
} else if(out->type() == Type::float16) {
#if COMPILE_FP16
- std::vector<marian::Tensor> ts({out, tensors...});
- bool div2 = std::all_of(ts.cbegin(), ts.cend(), [](marian::Tensor t){ return t->shape()[-1] % 2 == 0; });
- if(div2) {
- ElementTyped<halfx2>(functor, out, tensors...);
- } else {
- ElementTyped<half>(functor, out, tensors...);
- }
+ ElementTyped<half>(functor, out, tensors...);
#else
ABORT("FP16 not supported with chosen current hardware or CUDA version");
#endif
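
The element.cu change routes every elementwise functor through applyWithCast<float>: operands are promoted to float32, the functor runs in float32, and only the final result is cast back to the tensor's storage type (the halfx2 vectorized path is dropped in the same move). The pattern itself, shown one precision level up as a host-side C++ sketch, with float standing in for half storage and double standing in for the float32 compute type:

#include <cstdio>

// Compute wide, store narrow: the functor only ever sees the wide type.
template <typename T, class Functor>
void elementApply(T* out, const T* a, const T* b, int n, Functor f) {
  for(int i = 0; i < n; ++i)
    out[i] = (T)f((double)a[i], (double)b[i]);
}

int main() {
  float a[1] = {1e-30f}, b[1] = {1e-30f}, out[1];
  // The intermediate product a*b underflows to 0 in float but not in double,
  // so running the functor in the wide type preserves the correct result.
  elementApply(out, a, b, 1, [](double x, double y) { return (x * y) / y; });
  std::printf("%g\n", out[0]);  // prints 1e-30; a pure-float evaluation would give 0
  return 0;
}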
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index d55214bc..1347c3bb 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -16,15 +16,12 @@ namespace gpu {
namespace atomics {
static inline __device__ void atomicAdd(float *address, float val) {
- //*address += val;
::atomicAdd(address, val);
}
#if COMPILE_FP16
// @TODO: copied from CuTorch, adapt this better, give credit.
static inline __device__ void atomicAdd(half *address, half val) {
- //*address += val;
-
#if __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000 // compute capability 70 and higher with CUDA 10
::atomicAdd(address, val);
#else // __CUDA_ARCH__ < 700
@@ -50,7 +47,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
} while (assumed != old);
#endif // __CUDA_ARCH__
}
-#endif
+#endif // COMPILE_FP16
+
}
@@ -96,6 +94,81 @@ void IsNaN(const Tensor in, Ptr<Allocator> allocator, bool& isNaN, bool& isInf)
cudaStreamSynchronize(0);
}
+template <typename T>
+__global__ void gSanitizeGradient(T* in, int length,
+ bool* isNaN, bool* isInf,
+ bool pruneNaN, bool clipInf,
+ float forNaN = 0.f, float forInf = 65504.f, float forInfNeg = -65504.f) {
+ for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+ int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+ if(index < length) {
+ float v = (float)in[index];
+ // handle NaN
+ if(isnan(v)) {
+ if(pruneNaN) {
+ in[index] = (T)forNaN;
+ } else {
+ *isNaN = true;
+ }
+ }
+ // handle +/- Inf
+ if(isinf(v)) {
+ if(clipInf) {
+ in[index] = v > 0 ? (T)forInf : (T)forInfNeg;
+ } else {
+ *isInf = true;
+ }
+ }
+ }
+ }
+}
+
+// This function is meant to clean gradients, i.e. clip infinities and prune NaNs if required.
+// If all NaNs and Infs have been removed we return `true` for indicating a sane gradient.
+// If `clipInf` is set, infinities are replaced with the maximum/minimum non-inf value for the tensor.
+// In that case infinities do not result in a bad gradient, since they get clipped.
+// If `pruneNaN` is set, NaNs are replaced with 0. Since NaNs get removed now they do not result
+// in a bad gradient.
+// If NaNs or infinities are detected but not removed (either because of `pruneNaN=false` or `clipInf=false`),
+// we return `false` indicating a bad gradient.
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
+ cudaSetDevice(in->getDeviceId().no);
+
+ int length = in->size();
+
+ int threads = std::min(MAX_THREADS, length);
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+ auto mem = allocator->alloc<bool>(2);
+ bool* dIsNaN = &mem->data<bool>()[0];
+ bool* dIsInf = &mem->data<bool>()[1];
+ fill(in->getBackend(), dIsNaN, dIsNaN + 2, false);
+
+ float forNaN = 0.f;
+ float forInf = NumericLimits<float>(in->type()).max;
+ float forInfNeg = NumericLimits<float>(in->type()).lowest;
+
+ if(in->type() == Type::float32) {
+ gSanitizeGradient<<<blocks, threads>>>(in->data<float>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gSanitizeGradient<<<blocks, threads>>>(in->data<half>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
+#endif
+ } else {
+ ABORT("gSanitizeGradient for type {} not implemented", in->type());
+ }
+
+ bool isNaN, isInf;
+ CudaCopy(dIsNaN, dIsNaN + 1, &isNaN);
+ CudaCopy(dIsInf, dIsInf + 1, &isInf);
+
+ allocator->free(mem);
+
+ cudaStreamSynchronize(0);
+
+ return !isNaN && !isInf;
+}
+
template <bool add, typename To, typename From>
__global__ void gCopyCastTo(To* out, const From* in, int length) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
@@ -1090,7 +1163,7 @@ void PasteRows(Tensor out,
size_t rowsToCopy = indices->size();
int threads = std::min(MAX_THREADS, (int)cols);
-#if 1 // @TODO: make this configurable with a 'deterministic' flag
+#if 0 // @TODO: make this configurable with a 'deterministic' flag
// If we only use one block, then each core operates on a different column,
// hence the summation becomes deterministic.
// However, we only use e.g. 512 cores out of possibly 3000+, so this will be
@@ -1355,7 +1428,7 @@ __global__ void gGRUFastForward(T* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- T m = !mask || mask[j];
+ float m = !mask || mask[j];
T* rowOut = out + j * cols;
const T* rowState = state + j * cols;
@@ -1365,21 +1438,21 @@ __global__ void gGRUFastForward(T* out,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
- T r = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
+ float r = functional::Ops<float>::sigmoid((float)xWrow[i] + (float)sUrow[i] + (float)b[i]);
int k = i + cols;
- T z = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float z = functional::Ops<float>::sigmoid((float)xWrow[k] + (float)sUrow[k] + (float)b[k]);
int l = i + 2 * cols;
- T h;
+ float h;
if(final)
- h = functional::Ops<T>::tanh(xWrow[l] + (sUrow[l] + b[l]) * r);
+ h = functional::Ops<float>::tanh((float)xWrow[l] + ((float)sUrow[l] + (float)b[l]) * r);
else
- h = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] * r + b[l]);
+ h = functional::Ops<float>::tanh((float)xWrow[l] + (float)sUrow[l] * r + (float)b[l]);
- T out = ((T)1.f - z) * h + z * rowState[i];
- rowOut[i] = m * out + ((T)1.f - m) * rowState[i];
+ float out = (1.f - z) * h + z * (float)rowState[i];
+ rowOut[i] = (T)(m * out + (1.f - m) * (float)rowState[i]);
}
}
}
@@ -1441,7 +1514,7 @@ __global__ void gGRUFastBackward(T* outState,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- T m = !mask || mask[j];
+ float m = !mask || mask[j];
T* rowOutState = outState + j * cols;
T* rowOutXW = outXW + j * cols * 3;
@@ -1459,56 +1532,56 @@ __global__ void gGRUFastBackward(T* outState,
int k = i + cols;
int l = i + 2 * cols;
- T r = functional::Ops<T>::sigmoid(rowXW[i] + rowSU[i] + b[i]);
- T z = functional::Ops<T>::sigmoid(rowXW[k] + rowSU[k] + b[k]);
+ float r = functional::Ops<float>::sigmoid((float)rowXW[i] + (float)rowSU[i] + (float)b[i]);
+ float z = functional::Ops<float>::sigmoid((float)rowXW[k] + (float)rowSU[k] + (float)b[k]);
- T h;
+ float h;
if(final)
- h = functional::Ops<T>::tanh(rowXW[l] + (rowSU[l] + b[l]) * r);
+ h = functional::Ops<float>::tanh((float)rowXW[l] + ((float)rowSU[l] + (float)b[l]) * r);
else
- h = functional::Ops<T>::tanh(rowXW[l] + rowSU[l] * r + b[l]);
+ h = functional::Ops<float>::tanh((float)rowXW[l] + (float)rowSU[l] * r + (float)b[l]);
- T adj = rowAdj[i];
+ float adj = rowAdj[i];
- T t = ((T)1.f - z) * ((T)1.f - h * h);
+ float t = (1.f - z) * (1.f - h * h);
// df/ds
if(outState)
- rowOutState[i] += (m * z - m + (T)1.f) * adj;
+ rowOutState[i] += (T)((m * z - m + 1.f) * adj);
// df/d(xW_r) ...
- T dfdxW_r = m * r * ((T)1.f - r) * t * adj;
+ float dfdxW_r = m * r * (1.f - r) * t * adj;
if(final)
- dfdxW_r *= rowSU[l] + b[l];
+ dfdxW_r *= (float)rowSU[l] + (float)b[l];
else
- dfdxW_r *= rowSU[l];
+ dfdxW_r *= (float)rowSU[l];
if(outXW)
- rowOutXW[i] += dfdxW_r;
+ rowOutXW[i] += (T)dfdxW_r;
if(outSU)
- rowOutSU[i] += dfdxW_r;
+ rowOutSU[i] += (T)dfdxW_r;
if(outB)
- rowOutB[i] += dfdxW_r;
+ rowOutB[i] += (T)dfdxW_r;
// df/d(xW_z) ...
- T dfdxW_z = m * ((T)1.f - z) * z * (rowState[i] - h) * adj;
+ float dfdxW_z = m * (1.f - z) * z * ((float)rowState[i] - h) * adj;
if(outXW)
- rowOutXW[k] += dfdxW_z;
+ rowOutXW[k] += (T)dfdxW_z;
if(outSU)
- rowOutSU[k] += dfdxW_z;
+ rowOutSU[k] += (T)dfdxW_z;
if(outB)
- rowOutB[k] += dfdxW_z;
+ rowOutB[k] += (T)dfdxW_z;
// df/d(xW_x) ...
- T dfdxW_x = m * t * adj;
+ float dfdxW_x = m * t * adj;
if(outXW)
- rowOutXW[l] += dfdxW_x;
+ rowOutXW[l] += (T)dfdxW_x;
if(outSU)
- rowOutSU[l] += dfdxW_x * r;
+ rowOutSU[l] += (T)(dfdxW_x * r);
if(outB)
if(final)
- rowOutB[l] += dfdxW_x * r;
+ rowOutB[l] += (T)(dfdxW_x * r);
else
- rowOutB[l] += dfdxW_x;
+ rowOutB[l] += (T)dfdxW_x;
}
}
}
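
The GPU kernel above implements the policy on device; the policy itself is simple: optionally replace NaNs with 0, optionally clip +/-Inf to the largest finite value of the gradient's type, and report whether anything non-finite was left in place. A host-side C++ restatement on a raw float buffer, using std::numeric_limits instead of Marian's NumericLimits, just to pin down the semantics:

#include <cmath>
#include <cstdio>
#include <limits>

// Host-side restatement of SanitizeGradient's policy; returns true iff the
// buffer is sane afterwards (no NaN/Inf was left unremoved).
bool sanitizeGradient(float* g, int n, bool pruneNaN, bool clipInf,
                      float forInf = std::numeric_limits<float>::max()) {
  bool sawNaN = false, sawInf = false;
  for(int i = 0; i < n; ++i) {
    if(std::isnan(g[i])) {
      if(pruneNaN) g[i] = 0.f;     // NaN -> 0, the update can proceed
      else         sawNaN = true;  // otherwise report a bad gradient
    }
    if(std::isinf(g[i])) {
      if(clipInf) g[i] = g[i] > 0.f ? forInf : -forInf;  // clip into finite range
      else        sawInf = true;
    }
  }
  return !sawNaN && !sawInf;
}

int main() {
  float g[3] = {NAN, INFINITY, 1.f};
  bool sane = sanitizeGradient(g, 3, /*pruneNaN=*/true, /*clipInf=*/true);
  std::printf("sane=%d g=[%g, %g, %g]\n", (int)sane, g[0], g[1], g[2]);
  return 0;
}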
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 6e587953..dc29bf35 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -41,6 +41,25 @@ DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor);
DISPATCH2(AddCast, marian::Tensor, const marian::Tensor);
DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&);
+#ifdef CUDA_FOUND
+namespace gpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+#endif
+
+namespace cpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+
+static inline bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
+#ifdef CUDA_FOUND
+ if(in->getBackend()->getDeviceId().type == DeviceType::gpu)
+ return gpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+ else
+#endif
+ return cpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+}
+
template <class Functor, class... Tensors>
void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND
diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp
index e9c977b9..03e5acf4 100644
--- a/src/training/graph_group.cpp
+++ b/src/training/graph_group.cpp
@@ -10,25 +10,19 @@ GraphGroup::GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
mbRoundUp_(options_->get<bool>("mini-batch-round-up", true)) {
if(options_->hasAndNotEmpty("cost-scaling")) {
auto vcs = options_->get<std::vector<std::string>>("cost-scaling");
- costScale_ = true;
- float costExponent = std::stof(vcs[0]);
- costScaleFactor_ = std::pow(2.0f, costExponent);
-
- if(vcs.size() > 1) costScaleFreq_ = std::stoul(vcs[1]);
- if(vcs.size() > 2) costScaleMultiplier_ = std::stof(vcs[2]);
- if(vcs.size() > 3) costScaleNanTolerance_ = std::stof(vcs[3]);
- if(vcs.size() > 4) costScaleNanRange_ = std::stoul(vcs[4]);
- if(vcs.size() > 5) costScaleFactorMinimum_ = std::stof(vcs[5]);
+
+ costScaling_ = true;
+ costScalingFactor_ = std::stof( vcs[0]);
+ if(vcs.size() > 1) costScalingFreq_ = std::stoul(vcs[1]);
+ if(vcs.size() > 2) costScalingMultiplier_ = std::stof( vcs[2]);
+ if(vcs.size() > 3) costScalingFactorMinimum_ = std::stof( vcs[3]);
LOG_ONCE(info,
- "Training with cost scaling - factor: 2^{} = {}, frequency: {}, multiplier: {}, tolerance: {}, range: {}, minimum: {}",
- costExponent,
- costScaleFactor_,
- costScaleFreq_,
- costScaleMultiplier_,
- costScaleNanTolerance_,
- costScaleNanRange_,
- costScaleFactorMinimum_);
+ "Training with cost scaling - factor: {}, frequency: {}, multiplier: {}, minimum: {}",
+ costScalingFactor_,
+ costScalingFreq_,
+ costScalingMultiplier_,
+ costScalingFactorMinimum_);
}
if(options_->hasAndNotEmpty("dynamic-gradient-scaling")) {
@@ -96,21 +90,17 @@ void GraphGroup::initGraphsAndOpts() {
// given number of iterations. Usually we increase by 2 which adds
// one more bit for precision.
void GraphGroup::increaseCostScaleFactor() {
- if(!costScale_)
+ if(!costScaling_)
return;
noNanSeen_++;
size_t total = nanSeen_ + noNanSeen_;
- float nanPercent = noNanSeen_ == (float)nanSeen_ / (float)total; // total is at least 1 because of noNanSeen_++
- if(noNanSeen_ % costScaleFreq_ == 0) {
- costScaleFactor_ *= costScaleMultiplier_;
- LOG(debug,
- "NaN/Inf percentage {:.2f} after {} gradient updates. Increasing cost-scaling factor to {}",
- nanPercent,
- total,
- costScaleFactor_);
+ if(noNanSeen_ % costScalingFreq_ == 0) {
+ costScalingFactor_ *= costScalingMultiplier_;
+ if(isMainProcess())
+ LOG(debug, "No NaN/Inf after {} gradient updates. Increasing cost-scaling factor to {}", total, costScalingFactor_);
// Resetting counts after cost-scale change
noNanSeen_ = 0;
@@ -120,48 +110,56 @@ void GraphGroup::increaseCostScaleFactor() {
// call when a NaN was seen to decrease cost-scaling factor
void GraphGroup::decreaseCostScaleFactor() {
- if(!costScale_)
+ if(!costScaling_)
return;
nanSeen_++;
size_t total = nanSeen_ + noNanSeen_;
- float nanPercent = (float)nanSeen_ / (float)total; // total is at least 1 because of nanSeen_++
- if(total >= costScaleNanRange_ && nanPercent > costScaleNanTolerance_) {
- if(costScaleFactor_ > costScaleFactorMinimum_) {
- costScaleFactor_ /= costScaleMultiplier_;
- LOG(debug,
- "NaN/Inf percentage {:.2f} in {} gradient updates, reducing cost-scaling factor to {}",
- nanPercent,
- total,
- costScaleFactor_);
- } else {
- // @TODO: think if should this rather abort?
- LOG(warn,
- "NaN/Inf percentage {:.2f} in {} gradient updates, but cost-scaling factor {} is already at minimum",
- nanPercent,
- total,
- costScaleFactor_);
- }
- // Resetting counts after cost-scale change
- noNanSeen_ = 0;
- nanSeen_ = 0;
+ // do not reduce cost-scaling factor below minimum
+ if(costScalingFactor_ > costScalingFactorMinimum_)
+ costScalingFactor_ /= costScalingMultiplier_;
+
+ if(isMainProcess()) {
+ if(costScalingFactor_ > costScalingFactorMinimum_)
+ LOG(debug, "Seen NaN/Inf after {} gradient updates. Reduced cost-scaling factor to {}", total, costScalingFactor_);
+ else
+ LOG(debug, "Seen NaN/Inf after {} gradient updates, Reduced cost-scaling factor to minimum {}. Pruning NaNs now.", total, costScalingFactor_);
}
+
+ // Resetting counts after cost-scale change
+ noNanSeen_ = 0;
+ nanSeen_ = 0;
}
float GraphGroup::checkNanOrNorm(size_t i, size_t begin, size_t end) {
auto curGrad = graphs_[i]->params()->grads()->subtensor(begin, end-begin);
- if(checkGradientNan_ || costScale_) {
- bool hasNan = false, hasInf = false;
- IsNaN(curGrad, graphs_[i]->allocator(), hasNan, hasInf); // @TODO: make safe with different compiler options
- if(hasNan || hasInf) {
- LOG(debug, "Found Nan ({}) or Inf ({})", hasNan, hasInf);
+ // If costScaling_ then check for NaN values if the costScalingFactor_ is larger than
+ // the minimum. If a NaN value is seen we exit here and will reduce the factor next and
+ // this skips an update.
+ // If costScalingFactor_ is already at the minimum, prune the NaN values away. This replaces
+ // NaNs with 0. Updates are not skipped any more.
+ // Regardless of NaNs, we clip +/-inf to the largest corresponding values for the gradient value type.
+ // This changes the gradient but seems to be quite stable. In effect, for fp16 this is equivalent
+ // to gradient clipping at (65504.f / costScalingFactor_) which in most cases is still large.
+ if(costScaling_ || checkGradientNan_) {
+ bool pruneNaN = !checkGradientNan_ && costScalingFactor_ == costScalingFactorMinimum_;
+ bool clipInf = !checkGradientNan_;
+ bool saneGradient = SanitizeGradient(curGrad, graphs_[i]->allocator(), pruneNaN, clipInf);
+
+ // This should never happen, if it does, something is wrong with the kernel above and needs to be fixed.
+ ABORT_IF(pruneNaN && clipInf && !saneGradient, "We are removing NaNs and clipping Infs, but gradient is still not sane??");
+
+ if(!saneGradient) {
+ LOG(debug, "Found NaN");
return std::numeric_limits<float>::quiet_NaN();
}
}
-
+
+ // The optional clipping above will affect the norm here. The norm can be non-finite despite the above
+ // gradient sanitization, hence check again and propagate a NaN.
if(dynamicGradientScaling_) {
auto gNorm = L2Norm(curGrad, graphs_[i]->allocator());
if(isFinite(gNorm) && gNorm > 0.0)
@@ -197,8 +195,8 @@ float GraphGroup::executeAndCollectNorm(const std::function<float(size_t, size_t
float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords) {
float normalizationFactor = 1.f;
- if(costScale_)
- normalizationFactor *= costScaleFactor_;
+ if(costScaling_)
+ normalizationFactor *= costScalingFactor_;
if(options_->get<bool>("normalize-gradient"))
normalizationFactor *= updateTrgWords;
@@ -207,9 +205,9 @@ float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords)
return normalizationFactor;
if(dynamicGradientScaling_) {
- // make gradient norm invariant to changes in costScaleFactor_, luckily norm(c * g) = c * norm(g)
- if(costScale_)
- gNorm = gNorm / costScaleFactor_;
+ // make gradient norm invariant to changes in costScalingFactor_, luckily norm(c * g) = c * norm(g)
+ if(costScaling_)
+ gNorm = gNorm / costScalingFactor_;
// Normalize gradient norm w.r.t. number of labels in batch for statistics,
// there should be no gradient normalization before this point, @TODO: check this
@@ -288,9 +286,7 @@ void GraphGroup::load(const OptimizerBase::ScatterStateFunc& scatterFn) {
restoreFromCheckpoint(modelFileName, scatterFn);
} else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string nameInit = options_->get<std::string>("pretrained-model");
- LOG(info,
- "[training] Initializing model weights with pre-trained model {}",
- nameInit);
+ LOG(info, "[training] Initializing model weights with pre-trained model {}", nameInit);
size_t i = 0;
for(auto graph : graphs_)
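
Putting the graph_group.cpp changes together: while the cost-scaling factor is above its minimum, a non-finite gradient is merely reported, the update is skipped and the factor is reduced; at the minimum, NaNs are pruned and training continues; infinities are clipped whenever cost scaling is on and --check-gradient-nan is off. A condensed sketch of one step's gradient handling, reusing the CostScaler and sanitizeGradient sketches from earlier in this write-up and simplifying clipInf to always-on:

// Condensed per-step gradient handling under cost scaling (sketch only;
// CostScaler and sanitizeGradient are the illustrative helpers defined above).
void handleGradient(CostScaler& cs, float* grad, int n) {
  bool pruneNaN = cs.atMinimum();  // only prune once the factor sits at the floor
  bool sane = sanitizeGradient(grad, n, pruneNaN, /*clipInf=*/true);
  if(!sane) {
    cs.onBadGradient();            // reduce the factor ...
    return;                        // ... and skip this update entirely
  }
  // The optimizer divides by cs.factor (part of the normalization factor),
  // then applies the update as usual.
  cs.onGoodGradient();
}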
diff --git a/src/training/graph_group.h b/src/training/graph_group.h
index 422990b1..b7f2f7ef 100644
--- a/src/training/graph_group.h
+++ b/src/training/graph_group.h
@@ -60,22 +60,21 @@ protected:
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
- bool costScale_{false};
- float costScaleFactor_{1.f}; // @TODO, add current costScaleFactor_ to trainingState for serialization
- size_t costScaleFreq_{2000};
- float costScaleMultiplier_{2.f};
- float costScaleNanTolerance_{0.f};
- size_t costScaleNanRange_{1};
- float costScaleFactorMinimum_{1.f}; // @TODO make this configureable
+ bool costScaling_{false};
+ float costScalingFactor_{1.f}; // @TODO, add current costScalingFactor_ to trainingState for serialization
+ size_t costScalingFreq_{2000};
+ float costScalingMultiplier_{2.f};
+ float costScalingFactorMinimum_{1.f};
+
size_t noNanSeen_{0}; // @TODO, add current noNanSeen_ to trainingState for serialization
size_t nanSeen_{0};
+ bool checkGradientNan_{false};
+
bool dynamicGradientScaling_{false};
float dynamicGradientScalingFactor_{2.f};
bool dynamicGradientScalingUseLogs_{false};
- bool checkGradientNan_{false};
-
// determines the number of input streams (i.e. input files or fields in the TSV input) that need
// to be included in the batch, i.e. without alignments and weights
size_t numberOfInputFiles();
diff --git a/src/training/graph_group_async.cpp b/src/training/graph_group_async.cpp
index 72b06e48..f85f9cf8 100644
--- a/src/training/graph_group_async.cpp
+++ b/src/training/graph_group_async.cpp
@@ -143,13 +143,13 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
thread_local Tensor accGradients;
thread_local Ptr<TensorAllocator> accAlloc;
- ABORT_IF(costScale_ ,"Cost-scaling not implemented for AsyncSGD");
+ ABORT_IF(costScaling_ ,"Cost-scaling not implemented for AsyncSGD");
auto graph = graphs_[tid];
Ptr<RationalLoss> dynamicLoss = models_[tid]->build(graph, batch);
- if(costScaleFactor_ != 1.f) {
+ if(costScalingFactor_ != 1.f) {
// it's ok to go out of scope, this will still insert the new top node into the graph
- auto costNode = dynamicLoss->loss() * costScaleFactor_;
+ auto costNode = dynamicLoss->loss() * costScalingFactor_;
}
if(t % optimizerDelay_ == 0) {
diff --git a/src/training/graph_group_singleton.cpp b/src/training/graph_group_singleton.cpp
index 7dc86137..16261070 100644
--- a/src/training/graph_group_singleton.cpp
+++ b/src/training/graph_group_singleton.cpp
@@ -16,16 +16,16 @@ void SingletonGraph::execute(Ptr<data::Batch> batch) {
auto opt = optimizerShards_[0];
auto lossNode = model->build(graph, batch);
- if(costScaleFactor_ != 1.f) {
+ if(costScalingFactor_ != 1.f) {
// for fp16 training, it's ok to go out of scope, we do not use the scaled version for anything
- auto scaledLoss = lossNode->loss() * costScaleFactor_;
+ auto scaledLoss = lossNode->loss() * costScalingFactor_;
}
graph->forward();
graph->backward();
bool noNanOrInf = true;
- if(costScale_) {
+ if(costScaling_) {
// Are there NaNs in the gradient?
bool hasNan = false, hasInf = false;
IsNaN(graph->params()->grads(), graph->allocator(), hasNan, hasInf);
@@ -39,7 +39,7 @@ void SingletonGraph::execute(Ptr<data::Batch> batch) {
opt->update(graph->params()->vals(),
graph->params()->grads(),
batch->wordsTrg(),
- costScaleFactor_);
+ costScalingFactor_);
if(scheduler_) {
scheduler_->update(*lossNode, batch);
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index 8c06761e..c90a384e 100644
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
@@ -252,8 +252,8 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
{ // let loss go out of scope, frees memory
auto rationalLoss = models_[localDeviceIndex]->build(graph, subBatch);
- if(costScaleFactor_ != 1.f)
- rationalLoss->loss() * costScaleFactor_;
+ if(costScalingFactor_ != 1.f)
+ rationalLoss->loss() * costScalingFactor_;
graph->forward();
localDeviceLosses[localDeviceIndex] += *rationalLoss;
@@ -262,7 +262,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
graph->backward(/*zero=*/false); // (gradients are reset before we get here)
}
-#if 1
+#if 0 // @TODO: this can probably be removed now, keep around until confirmed.
// experimental and should eventually be somewhere else
// Handle local gradient explosion but only clip to largest possible value
// given number of GPUs and type. Should clip rarely. Also clips inf
@@ -284,7 +284,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices (globally) into shards
float gradNorm = 0.f;
- if(costScale_ || dynamicGradientScaling_ || checkGradientNan_) {
+ if(costScaling_ || dynamicGradientScaling_ || checkGradientNan_) {
// Wrapping member function
auto checkNanOrNorm = [&](size_t i, size_t begin, size_t end) {
return GraphGroup::checkNanOrNorm(i, begin, end);