github.com/marian-nmt/marian.git
author     Roman Grundkiewicz <rogrundk@microsoft.com>   2021-04-26 14:51:43 +0300
committer  Roman Grundkiewicz <rogrundk@microsoft.com>   2021-04-26 14:51:43 +0300
commit     49e379bba5c77c1b80927b7f0db5603e171a1903 (patch)
tree       7cd89540ba86333eff3d4f07a3ab0b8ef3db324f /src
parent     3e51ff387232f1096e9560980f0115ac734224f5 (diff)
Merged PR 18612: Early stopping on first, all, or any validation metrics
Adds `--early-stopping-on first|all|any`, which lets the user decide whether early stopping should take into account only the first, all, or any of the validation metrics.
Feature request: https://github.com/marian-nmt/marian-dev/issues/850
Regression tests: https://github.com/marian-nmt/marian-regression-tests/pull/79
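In practice the new flag is combined with `--valid-metrics` and the existing `--early-stopping` patience value, e.g. `--valid-metrics ce-mean-words bleu --early-stopping 10 --early-stopping-on any` (a hypothetical invocation). With `any`, training stops as soon as at least one metric has failed to improve for 10 consecutive validation steps; with `all`, every metric must have stalled that long; `first` (the default) preserves the previous behavior of consulting only the first metric.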
Diffstat (limited to 'src')
-rw-r--r--  src/common/config_parser.cpp     | 14
-rw-r--r--  src/common/config_validator.cpp  |  9
-rw-r--r--  src/training/scheduler.h         | 69
-rw-r--r--  src/training/training_state.h    |  4
-rw-r--r--  src/training/validator.cpp       | 14
-rw-r--r--  src/training/validator.h         |  2
6 files changed, 72 insertions, 40 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 6495db0e..f29b3630 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -244,7 +244,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Tie all embedding layers and output layer");
cli.add<bool>("--output-omit-bias",
"Do not use a bias vector in decoder output layer");
-
+
// Transformer options
cli.add<int>("--transformer-heads",
"Number of heads in multi-head attention (transformer)",
@@ -529,13 +529,13 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
"After this many updates about 90% of the mass of the exponential average comes from these updates",
100);
- cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
+ cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
"Re-scale gradient to have average gradient norm if (log) gradient norm diverges from average by arg1 sigmas. "
"If arg2 = \"log\" the statistics are recorded for the log of the gradient norm else use plain norm")
->implicit_val("2.f log");
- cli.add<bool>("--check-gradient-nan",
+ cli.add<bool>("--check-gradient-nan",
"Skip parameter update in case of NaNs in gradient");
- cli.add<bool>("--normalize-gradient",
+ cli.add<bool>("--normalize-gradient",
"Normalize gradient by multiplying with no. devices / total labels (not recommended and to be removed in the future)");
cli.add<std::vector<std::string>>("--train-embedder-rank",
@@ -574,6 +574,10 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);
+ cli.add<std::string>("--early-stopping-on",
+ "Decide if early stopping should take into account first, all, or any validation metrics"
+ "Possible values: first, all, any",
+ "first");
// decoding options
cli.add<size_t>("--beam-size,-b",
@@ -586,7 +590,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Maximum target length as source length times factor",
3);
cli.add<float>("--word-penalty",
- "Subtract (arg * translation length) from translation score ");
+ "Subtract (arg * translation length) from translation score");
cli.add<bool>("--allow-unk",
"Allow unknown words to appear in output");
cli.add<bool>("--n-best",
diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp
index b2400145..fea7578f 100644
--- a/src/common/config_validator.cpp
+++ b/src/common/config_validator.cpp
@@ -4,6 +4,8 @@
#include "common/utils.h"
#include "common/filesystem.h"
+#include <set>
+
namespace marian {
bool ConfigValidator::has(const std::string& key) const {
@@ -129,6 +131,11 @@ void ConfigValidator::validateOptionsTraining() const {
&& !get<std::vector<std::string>>("valid-sets").empty(),
errorMsg);
+ // check if --early-stopping-on has proper value
+ std::set<std::string> supportedStops = {"first", "all", "any"};
+ ABORT_IF(supportedStops.find(get<std::string>("early-stopping-on")) == supportedStops.end(),
+ "Supported options for --early-stopping-on are: first, all, any");
+
// validations for learning rate decaying
ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
@@ -145,7 +152,7 @@ void ConfigValidator::validateOptionsTraining() const {
// validate ULR options
ABORT_IF((has("ulr") && get<bool>("ulr") && (get<std::string>("ulr-query-vectors") == ""
|| get<std::string>("ulr-keys-vectors") == "")),
- "ULR enablign requires query and keys vectors specified with --ulr-query-vectors and "
+ "ULR requires query and keys vectors specified with --ulr-query-vectors and "
"--ulr-keys-vectors option");
// validate model quantization
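Because this check runs in `validateOptionsTraining()` when the configuration is validated, a misspelled value such as `--early-stopping-on al` aborts with the message above before any training starts, rather than silently falling back to the default.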
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index 9d2500f9..8d4fa30c 100644
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
@@ -28,7 +28,7 @@ private:
// (regardless if it's the 1st or nth epoch and if it's a new or continued training),
// which indicates the end of the training data stream from STDIN
bool endOfStdin_{false}; // true at the end of the epoch if training from STDIN;
-
+
// @TODO: figure out how to compute this with regard to updates as well, although maybe harder since no final value
// determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
float getScheduledLRDecayFactor(const TrainingState& state) const {
@@ -133,7 +133,7 @@ public:
Scheduler(Ptr<Options> options, Ptr<TrainingState> state, Ptr<IMPIWrapper> mpi = nullptr)
: options_(options), state_(state), mpi_(mpi),
gradientNormAvgWindow_(options_->get<size_t>("gradient-norm-average-window", 100)) {
-
+
// parse logical-epoch parameters
auto logicalEpochStr = options->get<std::vector<std::string>>("logical-epoch", {"1e", "0"});
ABORT_IF(logicalEpochStr.empty(), "Logical epoch information is missing?");
@@ -174,7 +174,7 @@ public:
size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
// if unit is labels, then account for the fact that our increment itself is not constant
-#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
+#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
if (mbWarmup.unit == SchedulingUnit::trgLabels)
progressRatio = std::sqrt(progressRatio);
#endif
@@ -207,7 +207,7 @@ public:
if(saveAndExitRequested()) // via SIGTERM
return false;
-#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
+#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
// stop if it reached the maximum number of epochs
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
if(stopAfterEpochs > 0 && calculateLogicalEpoch() > stopAfterEpochs)
@@ -231,10 +231,9 @@ public:
}
}
- // stop if the first validator did not improve for a given number of checks
+ // stop if the first/all/any validators did not improve for a given number of checks
size_t stopAfterStalled = options_->get<size_t>("early-stopping");
- if(stopAfterStalled > 0 && !validators_.empty()
- && stalled() >= stopAfterStalled)
+ if(stopAfterStalled > 0 && stalled() >= stopAfterStalled)
return false;
// stop if data streaming from STDIN is stopped
@@ -297,12 +296,11 @@ public:
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
return;
- bool firstValidator = true;
+ size_t stalledPrev = stalled();
for(auto validator : validators_) {
if(!validator)
continue;
- size_t stalledPrev = validator->stalled();
float value = 0;
if(!mpi_ || mpi_->isMainProcess()) {
// We run validation only in the main process, but this is risky with MPI.
@@ -330,34 +328,60 @@ public:
if(mpi_) {
// collect and broadcast validation result to all processes and bring validator up-to-date
mpi_->bCast(&value, 1, IMPIWrapper::getDataType(&value));
-
+
// @TODO: add function to validator?
mpi_->bCast(&validator->stalled(), 1, IMPIWrapper::getDataType(&validator->stalled()));
mpi_->bCast(&validator->lastBest(), 1, IMPIWrapper::getDataType(&validator->lastBest()));
}
- if(firstValidator)
- state_->validBest = value;
-
state_->validators[validator->type()]["last-best"] = validator->lastBest();
state_->validators[validator->type()]["stalled"] = validator->stalled();
-
- // notify training observers if the first validator did not improve
- if(firstValidator && validator->stalled() > stalledPrev)
- state_->newStalled(validator->stalled());
- firstValidator = false;
}
+ // notify training observers about stalled validation
+ size_t stalledNew = stalled();
+ if(stalledNew > stalledPrev)
+ state_->newStalled(stalledNew);
+
state_->validated = true;
}
+ // Returns the number of stalled validations according to --early-stopping-on
size_t stalled() {
+ std::string stopOn = options_->get<std::string>("early-stopping-on");
+ if(stopOn == "any")
+ return stalledMax();
+ if(stopOn == "all")
+ return stalledMin();
+ return stalled1st();
+ }
+
+ // Returns the number of stalled validations for the first validator
+ size_t stalled1st() {
if(!validators_.empty())
if(validators_[0])
return validators_[0]->stalled();
return 0;
}
+ // Returns the largest number of stalled validations across validators or 0 if there are no validators
+ size_t stalledMax() {
+ size_t max = 0;
+ for(auto validator : validators_)
+ if(validator && validator->stalled() > max)
+ max = validator->stalled();
+ return max;
+ }
+
+ // Returns the lowest number of stalled validations across validators or 0 if there are no validators
+ size_t stalledMin() {
+ size_t min = std::numeric_limits<std::size_t>::max();
+ for(auto validator : validators_)
+ if(validator && validator->stalled() < min)
+ min = validator->stalled();
+ return min == std::numeric_limits<std::size_t>::max() ? 0 : min;
+ }
+
void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg(), /*gradientNorm=*/0.f);
}
@@ -397,8 +421,8 @@ public:
if(gradientNorm) {
size_t range = std::min(gradientNormAvgWindow_, state_->batches);
- float alpha = 2.f / (float)(range + 1);
-
+ float alpha = 2.f / (float)(range + 1);
+
float delta = gradientNorm - state_->gradientNormAvg;
state_->gradientNormAvg = state_->gradientNormAvg + alpha * delta;
state_->gradientNormVar = (1.0f - alpha) * (state_->gradientNormVar + alpha * delta * delta);
@@ -440,7 +464,7 @@ public:
formatLogicalEpoch(),
state_->batches,
utils::withCommas(state_->samplesEpoch),
- formatLoss(lossType, dispLabelCounts, batchLabels, state_),
+ formatLoss(lossType, dispLabelCounts, batchLabels, state_),
timer_.elapsed(),
state_->wordsDisp / timer_.elapsed(),
state_->gradientNormAvg);
@@ -627,7 +651,8 @@ public:
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
- state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
+ state.warmupStart.n = state.getProgressIn(
+ SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
}
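The mapping from flag value to aggregate in `stalled()` above can look inverted at first sight: `any` takes the maximum stalled count across validators (one sufficiently stalled metric is enough to stop), while `all` takes the minimum (training only stops once every metric has stalled that long). A self-contained sketch of the same dispatch, using hypothetical stalled counts, makes the semantics concrete:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Same aggregation as Scheduler::stalled(), over hypothetical per-validator
    // stalled counts, e.g. for ce-mean-words, bleu, and chrf validators.
    std::size_t stalled(const std::vector<std::size_t>& counts, const std::string& stopOn) {
      if(counts.empty())
        return 0;
      if(stopOn == "any")  // max: stop when ANY validator has stalled long enough
        return *std::max_element(counts.begin(), counts.end());
      if(stopOn == "all")  // min: stop only when ALL validators have stalled
        return *std::min_element(counts.begin(), counts.end());
      return counts[0];    // "first": previous behavior
    }

    int main() {
      std::vector<std::size_t> counts = {4, 0, 2};
      std::cout << stalled(counts, "first") << "\n";  // 4
      std::cout << stalled(counts, "any")   << "\n";  // 4
      std::cout << stalled(counts, "all")   << "\n";  // 0
    }

With `--early-stopping 10`, training stops once the returned value reaches 10, so under `all` it continues as long as any single metric keeps improving.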
diff --git a/src/training/training_state.h b/src/training/training_state.h
index 7d62f060..e0c1ba5d 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -43,8 +43,6 @@ public:
size_t stalled{0};
// The largest number of stalled validations so far
size_t maxStalled{0};
- // Last best validation score
- float validBest{0.f};
std::string validator;
// List of validators
YAML::Node validators;
@@ -217,7 +215,6 @@ public:
stalled = config["stalled"].as<size_t>();
maxStalled = config["stalled-max"].as<size_t>();
- validBest = config["valid-best"].as<float>();
validator = config["validator"].as<std::string>();
validators = config["validators"];
reset = config["reset"].as<bool>();
@@ -259,7 +256,6 @@ public:
config["stalled"] = stalled;
config["stalled-max"] = maxStalled;
- config["valid-best"] = validBest;
config["validator"] = validator;
config["validators"] = validators;
config["reset"] = reset;
diff --git a/src/training/validator.cpp b/src/training/validator.cpp
index d824052f..ef1bac3d 100644
--- a/src/training/validator.cpp
+++ b/src/training/validator.cpp
@@ -447,7 +447,7 @@ SacreBleuValidator::SacreBleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Optio
ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids"); // should not really happen, but let's check.
- if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
+ if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
order_ = 6; // we compute stats over character ngrams up to length 6
// @TODO: remove, only used for saving?
@@ -613,12 +613,12 @@ void SacreBleuValidator::updateStats(std::vector<float>& stats,
LOG_VALID_ONCE(info, "First sentence's tokens as scored:");
LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false)));
LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false)));
-
+
if(useWordIds_)
updateStats(stats, cand, ref);
else
updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false));
-
+
}
// Re-implementation of BLEU metric from SacreBLEU
@@ -627,7 +627,7 @@ float SacreBleuValidator::calcBLEU(const std::vector<float>& stats) {
for(int i = 0; i < order_; ++i) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
-
+
if(commonNgrams == 0.f)
return 0.f;
logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams);
@@ -653,7 +653,7 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
float referencesNgrams = stats[statsPerOrder * i + 2];
-
+
if(hypothesesNgrams > 0 && referencesNgrams > 0) {
avgPrecision += commonNgrams / hypothesesNgrams;
avgRecall += commonNgrams / referencesNgrams;
@@ -666,10 +666,10 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
avgPrecision /= effectiveOrder;
avgRecall /= effectiveOrder;
-
+
if(avgPrecision + avgRecall == 0.f)
return 0.f;
-
+
auto betaSquare = beta * beta;
auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall);
return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU
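For reference, the score assembled above is the character n-gram $F_\beta$ measure used by ChrF, averaged over n-gram orders: $F_\beta = (1+\beta^2)\,P R / (\beta^2 P + R)$, where $P$ and $R$ are the averaged n-gram precision and recall; sacreBLEU's ChrF typically uses $\beta = 2$, and the final multiplication by 100 puts the result on a BLEU-like scale.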
diff --git a/src/training/validator.h b/src/training/validator.h
index d6e64d69..16bfd245 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -352,7 +352,7 @@ protected:
private:
const std::string metric_; // allowed values are: bleu, bleu-detok (same as bleu), bleu-segmented, chrf
bool computeChrF_{ false }; // should we compute ChrF instead of BLEU (BLEU by default)?
-
+
size_t order_{ 4 }; // 4-grams for BLEU by default
static const size_t statsPerOrder = 3; // 0: common ngrams, 1: candidate ngrams, 2: reference ngrams
bool useWordIds_{ false }; // compute BLEU score by matching numeric segment ids