author     Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-10-26 23:25:39 +0300
committer  Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-10-26 23:25:39 +0300
commit     1404201926b5b4e27993776d52dfac809e8556f4 (patch)
tree       10d4cda76a78a3a3f607b543fce6602367ab6487
parent     7f06f3c5d2035dac0cb4349bf29fbfa3e6bb5448 (diff)
Merged PR 21151: Cleaning up fp16 behavior
This PR improves the clipping and pruning behavior for NaNs and Infs during fp16 training, ultimately avoiding the underflow problems we have been facing so far.
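In short, every gradient shard is now sanitized in a single pass before the optimizer update: NaNs are either reported or pruned to zero, and infinities are either reported or clipped to the largest finite value of the gradient type. The following is a minimal plain-C++ sketch of that per-element rule, simplified from the gSanitizeGradient CUDA kernel in the diff below (the function name and the std::vector interface are illustrative, not the marian::Tensor API):

    #include <cmath>
    #include <vector>

    // Returns true if the gradient is sane after optional pruning/clipping,
    // mirroring the semantics of SanitizeGradient in the diff below.
    bool sanitizeGradient(std::vector<float>& grad, bool pruneNaN, bool clipInf,
                          float maxVal = 65504.f /* largest finite fp16 value */) {
      bool sawNaN = false, sawInf = false;
      for(float& g : grad) {
        if(std::isnan(g)) {
          if(pruneNaN) g = 0.f;       // replace NaN with 0; the update can proceed
          else         sawNaN = true; // report a bad gradient; the update is skipped
        }
        if(std::isinf(g)) {
          if(clipInf) g = g > 0 ? maxVal : -maxVal; // clip +/-inf to the largest finite value
          else        sawInf = true;
        }
      }
      return !sawNaN && !sawInf;
    }

Whether NaNs are pruned or merely reported depends on the cost-scaling state: pruning only kicks in once the scaling factor has been reduced to its minimum (see checkNanOrNorm in src/training/graph_group.cpp below).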
-rw-r--r--  src/common/aliases.cpp                 |   4
-rw-r--r--  src/common/config_parser.cpp           |   6
-rw-r--r--  src/common/definitions.h               |  10
-rw-r--r--  src/models/transformer.h               |  15
-rwxr-xr-x  src/tensors/cpu/tensor_operators.cpp   |   4
-rwxr-xr-x  src/tensors/gpu/element.cu             |  12
-rw-r--r--  src/tensors/gpu/tensor_operators.cu    | 147
-rw-r--r--  src/tensors/tensor_operators.h         |  19
-rw-r--r--  src/training/graph_group.cpp           | 118
-rw-r--r--  src/training/graph_group.h             |  17
-rw-r--r--  src/training/graph_group_async.cpp     |   6
-rw-r--r--  src/training/graph_group_singleton.cpp |   8
-rw-r--r--  src/training/graph_group_sync.cpp      |   8
13 files changed, 233 insertions, 141 deletions
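Before the per-file diffs, a note on the new option format: the old six-value --cost-scaling (log2 of the initial factor, window, multiplier, NaN tolerance, range, minimum) is reduced to four values: initial factor, frequency, multiplier, minimum factor. The sketch below shows the resulting schedule as a standalone state machine (CostScaler is a hypothetical name for illustration, not a Marian class; the real logic lives in GraphGroup::increaseCostScaleFactor and GraphGroup::decreaseCostScaleFactor in src/training/graph_group.cpp below):

    #include <cstddef>

    struct CostScaler {
      float factor = 256.f, multiplier = 2.f, minimum = 256.f; // new --fp16 defaults
      std::size_t freq = 1000, noNanSeen = 0, nanSeen = 0;

      // called after an update whose gradient was finite
      void increase() {
        if(++noNanSeen % freq == 0) { // e.g. 1000 clean updates since the last change
          factor *= multiplier;       // one more bit of precision for the scaled loss
          noNanSeen = nanSeen = 0;
        }
      }

      // called when a NaN/Inf gradient was seen
      void decrease() {
        ++nanSeen;
        if(factor > minimum)          // never drop below the minimum factor;
          factor /= multiplier;       // at the minimum, NaNs are pruned instead
        noNanSeen = nanSeen = 0;
      }
    };

On a non-finite gradient the update is skipped and decrease() is called, unless the factor already sits at its minimum, in which case NaNs are pruned away and training proceeds.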
diff --git a/src/common/aliases.cpp b/src/common/aliases.cpp
index 0be26a8c..99574fe1 100644
--- a/src/common/aliases.cpp
+++ b/src/common/aliases.cpp
@@ -29,8 +29,8 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
   cli.alias("fp16", "true", [&](YAML::Node& config) {
     if(mode_ == cli::mode::training) {
       config["precision"] = std::vector<std::string>({"float16", "float32"}); // inference type, optimization type, save type
-      // scaling factor (power of 2), frequency, multiplier at increase, tolerance, range, minium factor
-      config["cost-scaling"] = std::vector<std::string>({"0", "1000", "2", "0.05", "10", "1e-5"});
+      // scaling factor, frequency, multiplier at increase, minimum scaling factor
+      config["cost-scaling"] = std::vector<std::string>({"256.f", "1000", "2.f", "256.f"});
     } else {
       config["precision"] = std::vector<std::string>({"float16"}); // for inference we do not need the other types
     }
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index b3e8950b..51764cdc 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -522,15 +522,15 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
   // mixed precision training
   cli.add<bool>("--fp16",
       "Shortcut for mixed precision training with float16 and cost-scaling, "
-      "corresponds to: --precision float16 float32 --cost-scaling 0 1000 2 0.05 10 1e-5f");
+      "corresponds to: --precision float16 float32 --cost-scaling 256.f 1000 2.f 256.f");
   cli.add<std::vector<std::string>>("--precision",
       "Mixed precision training for forward/backward pass and optimizaton. "
       "Defines types for: forward/backward pass, optimization.",
       {"float32", "float32"});
   cli.add<std::vector<std::string>>("--cost-scaling",
       "Dynamic cost scaling for mixed precision training: "
-      "power of 2, scaling window, scaling factor, tolerance, range, minimum factor")
-    ->implicit_val("0.f 1000 2.f 0.05f 10 1e-5f");
+      "scaling factor, frequency, multiplier, minimum factor")
+    ->implicit_val("256.f 1000 2.f 256.f");
   cli.add<size_t>("--gradient-norm-average-window",
       "Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
       "After this many updates about 90% of the mass of the exponential average comes from these updates",
diff --git a/src/common/definitions.h b/src/common/definitions.h
index d2cf8aa4..d8a3ad46 100644
--- a/src/common/definitions.h
+++ b/src/common/definitions.h
@@ -106,24 +106,24 @@ using Weak = std::weak_ptr<T>;
 /** @brief Creates shared_ptr of any type, passes all arguments to any available
  * constructor */
 template <class T, typename... Args>
-Ptr<T> New(Args&&... args) {
-  return Ptr<T>(new T(std::forward<Args>(args)...));
+inline Ptr<T> New(Args&&... args) {
+  return std::make_shared<T>(std::forward<Args>(args)...);
 }
 
 template <class T>
-Ptr<T> New(Ptr<T> p) {
+inline Ptr<T> New(Ptr<T> p) {
   return Ptr<T>(p);
 }
 
 /** @brief Creates InstrusivePtr of any type, passes all arguments to any available
  * constructor */
 template <class T, typename... Args>
-IPtr<T> INew(Args&&... args) {
+inline IPtr<T> INew(Args&&... args) {
   return IPtr<T>(new T(std::forward<Args>(args)...));
 }
 
 template <class T>
-IPtr<T> INew(Ptr<T> p) {
+inline IPtr<T> INew(Ptr<T> p) {
   return IPtr<T>(p);
 }
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 2393ad73..b2c0f6be 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -147,8 +147,7 @@ public:
     int dimDepth = dimModel / dimHeads;
 
-    auto output
-        = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
+    auto output = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
 
     return transpose(output, {0, 2, 1, 3}); // [dimBatch*dimBeam, dimHeads, dimSteps, dimDepth]
   }
@@ -361,9 +360,9 @@ public:
   Expr LayerAttention(std::string prefix,
                       Expr input,         // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
-                      const Expr& keys,   // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
-                      const Expr& values, // ...?
-                      const Expr& mask,   // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
+                      Expr keys,          // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+                      Expr values,        // ...?
+                      Expr mask,          // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
                       int dimHeads,
                       bool cache = false,
                       bool saveAttentionWeights = false) {
@@ -373,6 +372,12 @@ public:
     auto opsPre = opt<std::string>("transformer-preprocess");
     auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb);
 
+    // fixes missing norm for keys and values in self-attention with pre-norm
+    if(input == keys)
+      keys = output;
+    if(input == values)
+      values = output;
+
     // multi-head self-attention over previous input
     output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 1afb8f64..f3964f91 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -24,6 +24,10 @@ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, b
   ABORT("Not implemented");
 }
 
+bool SanitizeGradient(marian::Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool /*pruneNaN*/, bool /*clipInf*/) {
+  ABORT("Not implemented");
+}
+
 template <bool add, typename To, typename From>
 void CopyCastTo(To* out, const From* in, int length) {
   for(int i = 0; i < length; ++i)
diff --git a/src/tensors/gpu/element.cu b/src/tensors/gpu/element.cu
index 6790efd4..e9cbe081 100755
--- a/src/tensors/gpu/element.cu
+++ b/src/tensors/gpu/element.cu
@@ -29,7 +29,9 @@ __global__ void gElement(
       indices[i] = tensors[i].shape().bindex(dims);
     }
 
-    tensors[0].data()[index] = functional::apply(functor, tensors, indices);
+    // This performs the internal application of the functor in float32 regardless of the input type.
+    // It seems there are no speed penalties but improved precision.
+    tensors[0].data()[index] = (T)functional::applyWithCast<float>(functor, tensors, indices);
   }
 }
 
@@ -65,13 +67,7 @@ void Element(Functor functor, Tensor out, Tensors... tensors) {
     ElementTyped<float>(functor, out, tensors...);
   } else if(out->type() == Type::float16) {
 #if COMPILE_FP16
-    std::vector<marian::Tensor> ts({out, tensors...});
-    bool div2 = std::all_of(ts.cbegin(), ts.cend(), [](marian::Tensor t){ return t->shape()[-1] % 2 == 0; });
-    if(div2) {
-      ElementTyped<halfx2>(functor, out, tensors...);
-    } else {
-      ElementTyped<half>(functor, out, tensors...);
-    }
+    ElementTyped<half>(functor, out, tensors...);
 #else
     ABORT("FP16 not supported with chosen current hardware or CUDA version");
 #endif
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index d55214bc..1347c3bb 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -16,15 +16,12 @@ namespace gpu {
 namespace atomics {
 
 static inline __device__ void atomicAdd(float *address, float val) {
-  //*address += val;
   ::atomicAdd(address, val);
 }
 
 #if COMPILE_FP16
 // @TODO: copied from CuTorch, adapt this better, give credit.
 static inline __device__ void atomicAdd(half *address, half val) {
-  //*address += val;
-
 #if __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000 // compute capability 70 and higher with CUDA 10
   ::atomicAdd(address, val);
 #else // __CUDA_ARCH__ < 700
@@ -50,7 +47,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
   } while (assumed != old);
 #endif // __CUDA_ARCH__
 }
-#endif
+#endif // COMPILE_FP16
+
 }
 
@@ -96,6 +94,81 @@ void IsNaN(const Tensor in, Ptr<Allocator> allocator, bool& isNaN, bool& isInf)
   cudaStreamSynchronize(0);
 }
 
+template <typename T>
+__global__ void gSanitizeGradient(T* in, int length,
+                                  bool* isNaN, bool* isInf,
+                                  bool pruneNaN, bool clipInf,
+                                  float forNaN = 0.f, float forInf = 65504.f, float forInfNeg = -65504.f) {
+  for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+    int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+    if(index < length) {
+      float v = (float)in[index];
+      // handle NaN
+      if(isnan(v)) {
+        if(pruneNaN) {
+          in[index] = (T)forNaN;
+        } else {
+          *isNaN = true;
+        }
+      }
+      // handle +/- Inf
+      if(isinf(v)) {
+        if(clipInf) {
+          in[index] = v > 0 ? (T)forInf : (T)forInfNeg;
+        } else {
+          *isInf = true;
+        }
+      }
+    }
+  }
+}
+
+// This function is meant to clean gradients, i.e. clip infinities and prune NaNs if required.
+// If all NaNs and Infs have been removed we return `true` for indicating a sane gradient.
+// If `clipInf` is set, infinities are replaced with the maximum/minimum non-inf value for the tensor.
+// In that case infinities do not result in a bad gradient, since they get clipped.
+// If `pruneNaN` is set, NaNs are replaced with 0. Since NaNs get removed now they do not result
+// in a bad gradient.
+// If NaNs or infinities are detected but not removed (either because of `pruneNaN=false` or `clipInf=false`),
+// we return `false` indicating a bad gradient.
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
+  cudaSetDevice(in->getDeviceId().no);
+
+  int length = in->size();
+
+  int threads = std::min(MAX_THREADS, length);
+  int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+  auto mem = allocator->alloc<bool>(2);
+  bool* dIsNaN = &mem->data<bool>()[0];
+  bool* dIsInf = &mem->data<bool>()[1];
+  fill(in->getBackend(), dIsNaN, dIsNaN + 2, false);
+
+  float forNaN    = 0.f;
+  float forInf    = NumericLimits<float>(in->type()).max;
+  float forInfNeg = NumericLimits<float>(in->type()).lowest;
+
+  if(in->type() == Type::float32) {
+    gSanitizeGradient<<<blocks, threads>>>(in->data<float>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
+#if COMPILE_FP16
+  } else if(in->type() == Type::float16) {
+    gSanitizeGradient<<<blocks, threads>>>(in->data<half>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
+#endif
+  } else {
+    ABORT("gSanitizeGradient for type {} not implemented", in->type());
+  }
+
+  bool isNaN, isInf;
+  CudaCopy(dIsNaN, dIsNaN + 1, &isNaN);
+  CudaCopy(dIsInf, dIsInf + 1, &isInf);
+
+  allocator->free(mem);
+
+  cudaStreamSynchronize(0);
+
+  return !isNaN && !isInf;
+}
+
 template <bool add, typename To, typename From>
 __global__ void gCopyCastTo(To* out, const From* in, int length) {
   for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
@@ -1090,7 +1163,7 @@ void PasteRows(Tensor out,
   size_t rowsToCopy = indices->size();
 
   int threads = std::min(MAX_THREADS, (int)cols);
-#if 1 // @TODO: make this configurable with a 'deterministic' flag
+#if 0 // @TODO: make this configurable with a 'deterministic' flag
   // If we only use one block, then each core operates on a different column,
   // hence the summation becomes deterministic.
   // However, we only use e.g. 512 cores out of possibly 3000+, so this will be
@@ -1355,7 +1428,7 @@ __global__ void gGRUFastForward(T* out,
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
     if(j < rows) {
-      T m = !mask || mask[j];
+      float m = !mask || mask[j];
       T* rowOut = out + j * cols;
       const T* rowState = state + j * cols;
 
@@ -1365,21 +1438,21 @@ __global__ void gGRUFastForward(T* out,
       for(int tid = 0; tid < cols; tid += blockDim.x) {
        int i = tid + threadIdx.x;
        if(i < cols) {
-          T r = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
+          float r = functional::Ops<float>::sigmoid((float)xWrow[i] + (float)sUrow[i] + (float)b[i]);
 
           int k = i + cols;
 
-          T z = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
+          float z = functional::Ops<float>::sigmoid((float)xWrow[k] + (float)sUrow[k] + (float)b[k]);
 
           int l = i + 2 * cols;
-          T h;
+          float h;
           if(final)
-            h = functional::Ops<T>::tanh(xWrow[l] + (sUrow[l] + b[l]) * r);
+            h = functional::Ops<float>::tanh((float)xWrow[l] + ((float)sUrow[l] + (float)b[l]) * r);
           else
-            h = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] * r + b[l]);
+            h = functional::Ops<float>::tanh((float)xWrow[l] + (float)sUrow[l] * r + (float)b[l]);
 
-          T out = ((T)1.f - z) * h + z * rowState[i];
-          rowOut[i] = m * out + ((T)1.f - m) * rowState[i];
+          float out = (1.f - z) * h + z * (float)rowState[i];
+          rowOut[i] = (T)(m * out + (1.f - m) * (float)rowState[i]);
         }
       }
     }
@@ -1441,7 +1514,7 @@ __global__ void gGRUFastBackward(T* outState,
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
     if(j < rows) {
-      T m = !mask || mask[j];
+      float m = !mask || mask[j];
 
       T* rowOutState = outState + j * cols;
       T* rowOutXW = outXW + j * cols * 3;
@@ -1459,56 +1532,56 @@ __global__ void gGRUFastBackward(T* outState,
           int k = i + cols;
           int l = i + 2 * cols;
 
-          T r = functional::Ops<T>::sigmoid(rowXW[i] + rowSU[i] + b[i]);
-          T z = functional::Ops<T>::sigmoid(rowXW[k] + rowSU[k] + b[k]);
+          float r = functional::Ops<float>::sigmoid((float)rowXW[i] + (float)rowSU[i] + (float)b[i]);
+          float z = functional::Ops<float>::sigmoid((float)rowXW[k] + (float)rowSU[k] + (float)b[k]);
 
-          T h;
+          float h;
           if(final)
-            h = functional::Ops<T>::tanh(rowXW[l] + (rowSU[l] + b[l]) * r);
+            h = functional::Ops<float>::tanh((float)rowXW[l] + ((float)rowSU[l] + (float)b[l]) * r);
           else
-            h = functional::Ops<T>::tanh(rowXW[l] + rowSU[l] * r + b[l]);
+            h = functional::Ops<float>::tanh((float)rowXW[l] + (float)rowSU[l] * r + (float)b[l]);
 
-          T adj = rowAdj[i];
+          float adj = rowAdj[i];
 
-          T t = ((T)1.f - z) * ((T)1.f - h * h);
+          float t = (1.f - z) * (1.f - h * h);
 
           // df/ds
           if(outState)
-            rowOutState[i] += (m * z - m + (T)1.f) * adj;
+            rowOutState[i] += (T)((m * z - m + 1.f) * adj);
 
           // df/d(xW_r) ...
-          T dfdxW_r = m * r * ((T)1.f - r) * t * adj;
+          float dfdxW_r = m * r * (1.f - r) * t * adj;
           if(final)
-            dfdxW_r *= rowSU[l] + b[l];
+            dfdxW_r *= (float)rowSU[l] + (float)b[l];
           else
-            dfdxW_r *= rowSU[l];
+            dfdxW_r *= (float)rowSU[l];
           if(outXW)
-            rowOutXW[i] += dfdxW_r;
+            rowOutXW[i] += (T)dfdxW_r;
           if(outSU)
-            rowOutSU[i] += dfdxW_r;
+            rowOutSU[i] += (T)dfdxW_r;
           if(outB)
-            rowOutB[i] += dfdxW_r;
+            rowOutB[i] += (T)dfdxW_r;
 
           // df/d(xW_z) ...
-          T dfdxW_z = m * ((T)1.f - z) * z * (rowState[i] - h) * adj;
+          float dfdxW_z = m * (1.f - z) * z * ((float)rowState[i] - h) * adj;
           if(outXW)
-            rowOutXW[k] += dfdxW_z;
+            rowOutXW[k] += (T)dfdxW_z;
           if(outSU)
-            rowOutSU[k] += dfdxW_z;
+            rowOutSU[k] += (T)dfdxW_z;
           if(outB)
-            rowOutB[k] += dfdxW_z;
+            rowOutB[k] += (T)dfdxW_z;
 
           // df/d(xW_x) ...
-          T dfdxW_x = m * t * adj;
+          float dfdxW_x = m * t * adj;
           if(outXW)
-            rowOutXW[l] += dfdxW_x;
+            rowOutXW[l] += (T)dfdxW_x;
           if(outSU)
-            rowOutSU[l] += dfdxW_x * r;
+            rowOutSU[l] += (T)(dfdxW_x * r);
           if(outB)
             if(final)
-              rowOutB[l] += dfdxW_x * r;
+              rowOutB[l] += (T)(dfdxW_x * r);
             else
-              rowOutB[l] += dfdxW_x;
+              rowOutB[l] += (T)dfdxW_x;
         }
       }
     }
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 6e587953..dc29bf35 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -41,6 +41,25 @@ DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor);
 DISPATCH2(AddCast, marian::Tensor, const marian::Tensor);
 DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&);
 
+#ifdef CUDA_FOUND
+namespace gpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+#endif
+
+namespace cpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+
+static inline bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
+#ifdef CUDA_FOUND
+  if(in->getBackend()->getDeviceId().type == DeviceType::gpu)
+    return gpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+  else
+#endif
+    return cpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+}
+
 template <class Functor, class... Tensors>
 void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
 #ifdef CUDA_FOUND
diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp
index e9c977b9..03e5acf4 100644
--- a/src/training/graph_group.cpp
+++ b/src/training/graph_group.cpp
@@ -10,25 +10,19 @@ GraphGroup::GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
       mbRoundUp_(options_->get<bool>("mini-batch-round-up", true)) {
   if(options_->hasAndNotEmpty("cost-scaling")) {
     auto vcs = options_->get<std::vector<std::string>>("cost-scaling");
-    costScale_ = true;
-    float costExponent = std::stof(vcs[0]);
-    costScaleFactor_ = std::pow(2.0f, costExponent);
-
-    if(vcs.size() > 1) costScaleFreq_ = std::stoul(vcs[1]);
-    if(vcs.size() > 2) costScaleMultiplier_ = std::stof(vcs[2]);
-    if(vcs.size() > 3) costScaleNanTolerance_ = std::stof(vcs[3]);
-    if(vcs.size() > 4) costScaleNanRange_ = std::stoul(vcs[4]);
-    if(vcs.size() > 5) costScaleFactorMinimum_ = std::stof(vcs[5]);
+
+    costScaling_ = true;
+    costScalingFactor_ = std::stof(vcs[0]);
+    if(vcs.size() > 1) costScalingFreq_          = std::stoul(vcs[1]);
+    if(vcs.size() > 2) costScalingMultiplier_    = std::stof(vcs[2]);
+    if(vcs.size() > 3) costScalingFactorMinimum_ = std::stof(vcs[3]);
 
     LOG_ONCE(info,
-             "Training with cost scaling - factor: 2^{} = {}, frequency: {}, multiplier: {}, tolerance: {}, range: {}, minimum: {}",
-             costExponent,
-             costScaleFactor_,
-             costScaleFreq_,
-             costScaleMultiplier_,
-             costScaleNanTolerance_,
-             costScaleNanRange_,
-             costScaleFactorMinimum_);
+             "Training with cost scaling - factor: {}, frequency: {}, multiplier: {}, minimum: {}",
+             costScalingFactor_,
+             costScalingFreq_,
+             costScalingMultiplier_,
+             costScalingFactorMinimum_);
   }
 
   if(options_->hasAndNotEmpty("dynamic-gradient-scaling")) {
@@ -96,21 +90,17 @@ void GraphGroup::initGraphsAndOpts() {
 // given number of iterations. Usually we increase by 2 which adds
 // one more bit for precision.
 void GraphGroup::increaseCostScaleFactor() {
-  if(!costScale_)
+  if(!costScaling_)
     return;
 
   noNanSeen_++;
 
   size_t total = nanSeen_ + noNanSeen_;
-  float nanPercent = (float)nanSeen_ / (float)total; // total is at least 1 because of noNanSeen_++
 
-  if(noNanSeen_ % costScaleFreq_ == 0) {
-    costScaleFactor_ *= costScaleMultiplier_;
-    LOG(debug,
-        "NaN/Inf percentage {:.2f} after {} gradient updates. Increasing cost-scaling factor to {}",
-        nanPercent,
-        total,
-        costScaleFactor_);
+  if(noNanSeen_ % costScalingFreq_ == 0) {
+    costScalingFactor_ *= costScalingMultiplier_;
+    if(isMainProcess())
+      LOG(debug, "No NaN/Inf after {} gradient updates. Increasing cost-scaling factor to {}", total, costScalingFactor_);
 
     // Resetting counts after cost-scale change
     noNanSeen_ = 0;
@@ -120,48 +110,56 @@ void GraphGroup::increaseCostScaleFactor() {
 // call when a NaN was seen to decrease cost-scaling factor
 void GraphGroup::decreaseCostScaleFactor() {
-  if(!costScale_)
+  if(!costScaling_)
     return;
 
   nanSeen_++;
 
   size_t total = nanSeen_ + noNanSeen_;
-  float nanPercent = (float)nanSeen_ / (float)total; // total is at least 1 because of nanSeen_++
-  if(total >= costScaleNanRange_ && nanPercent > costScaleNanTolerance_) {
-    if(costScaleFactor_ > costScaleFactorMinimum_) {
-      costScaleFactor_ /= costScaleMultiplier_;
-      LOG(debug,
-          "NaN/Inf percentage {:.2f} in {} gradient updates, reducing cost-scaling factor to {}",
-          nanPercent,
-          total,
-          costScaleFactor_);
-    } else {
-      // @TODO: think if should this rather abort?
-      LOG(warn,
-          "NaN/Inf percentage {:.2f} in {} gradient updates, but cost-scaling factor {} is already at minimum",
-          nanPercent,
-          total,
-          costScaleFactor_);
-    }
-    // Resetting counts after cost-scale change
-    noNanSeen_ = 0;
-    nanSeen_ = 0;
+
+  // do not reduce cost-scaling factor below minimum
+  if(costScalingFactor_ > costScalingFactorMinimum_)
+    costScalingFactor_ /= costScalingMultiplier_;
+
+  if(isMainProcess()) {
+    if(costScalingFactor_ > costScalingFactorMinimum_)
+      LOG(debug, "Seen NaN/Inf after {} gradient updates. Reduced cost-scaling factor to {}", total, costScalingFactor_);
+    else
+      LOG(debug, "Seen NaN/Inf after {} gradient updates. Reduced cost-scaling factor to minimum {}. Pruning NaNs now.", total, costScalingFactor_);
   }
+
+  // Resetting counts after cost-scale change
+  noNanSeen_ = 0;
+  nanSeen_ = 0;
 }
 
 float GraphGroup::checkNanOrNorm(size_t i, size_t begin, size_t end) {
   auto curGrad = graphs_[i]->params()->grads()->subtensor(begin, end-begin);
 
-  if(checkGradientNan_ || costScale_) {
-    bool hasNan = false, hasInf = false;
-    IsNaN(curGrad, graphs_[i]->allocator(), hasNan, hasInf); // @TODO: make safe with different compiler options
-    if(hasNan || hasInf) {
-      LOG(debug, "Found Nan ({}) or Inf ({})", hasNan, hasInf);
+  // If costScaling_ then check for NaN values if the costScalingFactor_ is larger than
+  // the minimum. If a NaN value is seen we exit here and will reduce the factor next and
+  // this skips an update.
+  // If costScalingFactor_ is already at the minimum, prune the NaN values away. This replaces
+  // NaNs with 0. Updates are not skipped any more.
+  // Regardless of NaNs, we clip +/-inf to the largest corresponding values for the gradient value type.
+  // This changes the gradient but seems to be quite stable. In effect, for fp16 this is equivalent
+  // to gradient clipping at (65504.f / costScalingFactor_) which in most cases is still large.
+  if(costScaling_ || checkGradientNan_) {
+    bool pruneNaN = !checkGradientNan_ && costScalingFactor_ == costScalingFactorMinimum_;
+    bool clipInf  = !checkGradientNan_;
+    bool saneGradient = SanitizeGradient(curGrad, graphs_[i]->allocator(), pruneNaN, clipInf);
+
+    // This should never happen, if it does, something is wrong with the kernel above and needs to be fixed.
+    ABORT_IF(pruneNaN && clipInf && !saneGradient, "We are removing NaNs and clipping Infs, but gradient is still not sane??");
+
+    if(!saneGradient) {
+      LOG(debug, "Found NaN");
       return std::numeric_limits<float>::quiet_NaN();
     }
   }
-
+
+  // The optional clipping above will affect the norm here. The norm can be non-finite despite the above
+  // gradient sanitization, hence check again and propagate a NaN.
   if(dynamicGradientScaling_) {
     auto gNorm = L2Norm(curGrad, graphs_[i]->allocator());
     if(isFinite(gNorm) && gNorm > 0.0)
@@ -197,8 +195,8 @@ float GraphGroup::executeAndCollectNorm(const std::function<float(size_t, size_t
 float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords) {
   float normalizationFactor = 1.f;
 
-  if(costScale_)
-    normalizationFactor *= costScaleFactor_;
+  if(costScaling_)
+    normalizationFactor *= costScalingFactor_;
 
   if(options_->get<bool>("normalize-gradient"))
     normalizationFactor *= updateTrgWords;
@@ -207,9 +205,9 @@ float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords)
     return normalizationFactor;
 
   if(dynamicGradientScaling_) {
-    // make gradient norm invariant to changes in costScaleFactor_, luckily norm(c * g) = c * norm(g)
-    if(costScale_)
-      gNorm = gNorm / costScaleFactor_;
+    // make gradient norm invariant to changes in costScalingFactor_, luckily norm(c * g) = c * norm(g)
+    if(costScaling_)
+      gNorm = gNorm / costScalingFactor_;
 
     // Normalize gradient norm w.r.t. number of labels in batch for statistics,
     // there should be no gradient normalization before this point, @TODO: check this
@@ -288,9 +286,7 @@ void GraphGroup::load(const OptimizerBase::ScatterStateFunc& scatterFn) {
     restoreFromCheckpoint(modelFileName, scatterFn);
   } else if(options_->hasAndNotEmpty("pretrained-model")) {
     std::string nameInit = options_->get<std::string>("pretrained-model");
-    LOG(info,
-        "[training] Initializing model weights with pre-trained model {}",
-        nameInit);
+    LOG(info, "[training] Initializing model weights with pre-trained model {}", nameInit);
 
     size_t i = 0;
     for(auto graph : graphs_)
diff --git a/src/training/graph_group.h b/src/training/graph_group.h
index 422990b1..b7f2f7ef 100644
--- a/src/training/graph_group.h
+++ b/src/training/graph_group.h
@@ -60,22 +60,21 @@ protected:
   double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
   bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
 
-  bool costScale_{false};
-  float costScaleFactor_{1.f}; // @TODO, add current costScaleFactor_ to trainingState for serialization
-  size_t costScaleFreq_{2000};
-  float costScaleMultiplier_{2.f};
-  float costScaleNanTolerance_{0.f};
-  size_t costScaleNanRange_{1};
-  float costScaleFactorMinimum_{1.f}; // @TODO make this configureable
+  bool costScaling_{false};
+  float costScalingFactor_{1.f}; // @TODO, add current costScalingFactor_ to trainingState for serialization
+  size_t costScalingFreq_{2000};
+  float costScalingMultiplier_{2.f};
+  float costScalingFactorMinimum_{1.f};
+
   size_t noNanSeen_{0}; // @TODO, add current noNanSeen_ to trainingState for serialization
   size_t nanSeen_{0};
 
+  bool checkGradientNan_{false};
+
   bool dynamicGradientScaling_{false};
   float dynamicGradientScalingFactor_{2.f};
   bool dynamicGradientScalingUseLogs_{false};
 
-  bool checkGradientNan_{false};
-
   // determines the number of input streams (i.e. input files or fields in the TSV input) that need
   // to be included in the batch, i.e. without alignments and weights
   size_t numberOfInputFiles();
diff --git a/src/training/graph_group_async.cpp b/src/training/graph_group_async.cpp
index 72b06e48..f85f9cf8 100644
--- a/src/training/graph_group_async.cpp
+++ b/src/training/graph_group_async.cpp
@@ -143,13 +143,13 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
     thread_local Tensor accGradients;
     thread_local Ptr<TensorAllocator> accAlloc;
 
-    ABORT_IF(costScale_, "Cost-scaling not implemented for AsyncSGD");
+    ABORT_IF(costScaling_, "Cost-scaling not implemented for AsyncSGD");
 
     auto graph = graphs_[tid];
     Ptr<RationalLoss> dynamicLoss = models_[tid]->build(graph, batch);
-    if(costScaleFactor_ != 1.f) {
+    if(costScalingFactor_ != 1.f) {
       // it's ok to go out of scope, this will still insert the new top node into the graph
-      auto costNode = dynamicLoss->loss() * costScaleFactor_;
+      auto costNode = dynamicLoss->loss() * costScalingFactor_;
     }
 
     if(t % optimizerDelay_ == 0) {
diff --git a/src/training/graph_group_singleton.cpp b/src/training/graph_group_singleton.cpp
index 7dc86137..16261070 100644
--- a/src/training/graph_group_singleton.cpp
+++ b/src/training/graph_group_singleton.cpp
@@ -16,16 +16,16 @@ void SingletonGraph::execute(Ptr<data::Batch> batch) {
   auto opt = optimizerShards_[0];
 
   auto lossNode = model->build(graph, batch);
-  if(costScaleFactor_ != 1.f) {
+  if(costScalingFactor_ != 1.f) {
     // for fp16 training, it's ok to go out of scope, we do not use the scaled version for anything
-    auto scaledLoss = lossNode->loss() * costScaleFactor_;
+    auto scaledLoss = lossNode->loss() * costScalingFactor_;
   }
 
   graph->forward();
   graph->backward();
 
   bool noNanOrInf = true;
-  if(costScale_) {
+  if(costScaling_) {
     // Are there NaNs in the gradient?
     bool hasNan = false, hasInf = false;
     IsNaN(graph->params()->grads(), graph->allocator(), hasNan, hasInf);
@@ -39,7 +39,7 @@ void SingletonGraph::execute(Ptr<data::Batch> batch) {
     opt->update(graph->params()->vals(),
                 graph->params()->grads(),
                 batch->wordsTrg(),
-                costScaleFactor_);
+                costScalingFactor_);
 
     if(scheduler_) {
       scheduler_->update(*lossNode, batch);
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index 8c06761e..c90a384e 100644
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
@@ -252,8 +252,8 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
     { // let loss go out of scope, frees memory
       auto rationalLoss = models_[localDeviceIndex]->build(graph, subBatch);
-      if(costScaleFactor_ != 1.f)
-        rationalLoss->loss() * costScaleFactor_;
+      if(costScalingFactor_ != 1.f)
+        rationalLoss->loss() * costScalingFactor_;
       graph->forward();
 
       localDeviceLosses[localDeviceIndex] += *rationalLoss;
@@ -262,7 +262,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
     graph->backward(/*zero=*/false); // (gradients are reset before we get here)
   }
 
-#if 1
+#if 0 // @TODO: this can probably be removed now, keep around until confirmed.
   // experimental and should eventually be somewhere else
   // Handle local gradient explosion but only clip to largest possible value
   // given number of GPUs and type. Should clip rarely. Also clips inf
@@ -284,7 +284,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
   comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices (globally) into shards
 
   float gradNorm = 0.f;
-  if(costScale_ || dynamicGradientScaling_ || checkGradientNan_) {
+  if(costScaling_ || dynamicGradientScaling_ || checkGradientNan_) {
     // Wrapping member function
     auto checkNanOrNorm = [&](size_t i, size_t begin, size_t end) {
       return GraphGroup::checkNanOrNorm(i, begin, end);
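A back-of-the-envelope check on the clipping threshold mentioned in checkNanOrNorm above (the numbers come from the diff; the arithmetic is ours): gradients are computed on the loss scaled by costScalingFactor_, so clipping infinities at the fp16 maximum of 65504 corresponds to clipping the unscaled gradient at 65504 / costScalingFactor_. With the new default and minimum factor of 256 this is 65504 / 256 = 255.875, far above typical gradient magnitudes, so the clipping should rarely distort an otherwise sane gradient.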