diff options
author | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-03-22 18:58:04 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-22 18:58:04 +0300 |
commit | 0394d2cdbe42b7ad7fc02854da8e063f2a23496f (patch) | |
tree | af669285a1906ba07e0397bd19e1ef33ae9b633a /src | |
parent | 096c48e51cd2e61bb275345d7cca99cbfd6bc5c7 (diff) |
Display decoder speed statistics with --stat-freq N (#841)
Display decoder time statistics if requested
Diffstat (limited to 'src')
-rw-r--r-- | src/common/config_parser.cpp | 3 | ||||
-rw-r--r-- | src/common/scheduling_parameter.h | 53 | ||||
-rw-r--r-- | src/training/training_state.h | 44 | ||||
-rw-r--r-- | src/translator/translator.h | 80 |
4 files changed, 126 insertions, 54 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 3baa13ea..602509c5 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -658,6 +658,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { ->implicit_val("1"); cli.add<bool>("--word-scores", "Print word-level scores. One score per subword unit, not normalized even if --normalize"); + cli.add<std::string/*SchedulerPeriod*/>("--stat-freq", + "Display speed information every arg mini-batches. Disabled by default with 0, set to value larger than 0 to activate", + "0"); #ifdef USE_SENTENCEPIECE cli.add<bool>("--no-spm-decode", "Keep the output segmented into SentencePiece subwords"); diff --git a/src/common/scheduling_parameter.h b/src/common/scheduling_parameter.h new file mode 100644 index 00000000..b6e39acf --- /dev/null +++ b/src/common/scheduling_parameter.h @@ -0,0 +1,53 @@ +#pragma once + +#include "common/logging.h" +#include "common/utils.h" + +#include <string> + +namespace marian { + +// support for scheduling parameters that can be expressed with a unit, such as --lr-decay-inv-sqrt +enum class SchedulingUnit { + trgLabels, // "t": number of target labels seen so far + updates, // "u": number of updates so far (batches) + epochs // "e": number of epochs begun so far (very first epoch is 1) +}; + +struct SchedulingParameter { + size_t n{0}; // number of steps measured in 'unit' + SchedulingUnit unit{SchedulingUnit::updates}; // unit of value + + // parses scheduling parameters of the form NU where N=unsigned int and U=unit + // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels), + // "100e" (100 epochs). + static SchedulingParameter parse(std::string param) { + SchedulingParameter res; + if(!param.empty() && param.back() >= 'a') { + switch(param.back()) { + case 't': res.unit = SchedulingUnit::trgLabels; break; + case 'u': res.unit = SchedulingUnit::updates; break; + case 'e': res.unit = SchedulingUnit::epochs; break; + default: ABORT("invalid unit '{}' in {}", param.back(), param); + } + param.pop_back(); + } + double number = utils::parseNumber(param); + res.n = (size_t)number; + ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers"); // @TODO: do they? + return res; + } + + operator bool() const { return n > 0; } // check whether it is specified + + operator std::string() const { // convert back for storing in config + switch(unit) { + case SchedulingUnit::trgLabels: return std::to_string(n) + "t"; + case SchedulingUnit::updates : return std::to_string(n) + "u"; + case SchedulingUnit::epochs : return std::to_string(n) + "e"; + default: ABORT("corrupt enum value for scheduling unit"); + } + } +}; + +}
\ No newline at end of file diff --git a/src/training/training_state.h b/src/training/training_state.h index 459e33cf..7d62f060 100644 --- a/src/training/training_state.h +++ b/src/training/training_state.h @@ -2,6 +2,7 @@ #include "common/definitions.h" #include "common/filesystem.h" +#include "common/scheduling_parameter.h" #include "common/utils.h" #include <fstream> @@ -22,49 +23,6 @@ public: virtual void actAfterLoaded(TrainingState&) {} }; -// support for scheduling parameters that can be expressed with a unit, such as --lr-decay-inv-sqrt -enum class SchedulingUnit { - trgLabels, // "t": number of target labels seen so far - updates, // "u": number of updates so far (batches) - epochs // "e": number of epochs begun so far (very first epoch is 1) -}; - -struct SchedulingParameter { - size_t n{0}; // number of steps measured in 'unit' - SchedulingUnit unit{SchedulingUnit::updates}; // unit of value - - // parses scheduling parameters of the form NU where N=unsigned int and U=unit - // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels), - // "100e" (100 epochs). - static SchedulingParameter parse(std::string param) { - SchedulingParameter res; - if(!param.empty() && param.back() >= 'a') { - switch(param.back()) { - case 't': res.unit = SchedulingUnit::trgLabels; break; - case 'u': res.unit = SchedulingUnit::updates; break; - case 'e': res.unit = SchedulingUnit::epochs; break; - default: ABORT("invalid unit '{}' in {}", param.back(), param); - } - param.pop_back(); - } - double number = utils::parseNumber(param); - res.n = (size_t)number; - ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers"); // @TODO: do they? - return res; - } - - operator bool() const { return n > 0; } // check whether it is specified - - operator std::string() const { // convert back for storing in config - switch(unit) { - case SchedulingUnit::trgLabels: return std::to_string(n) + "t"; - case SchedulingUnit::updates : return std::to_string(n) + "u"; - case SchedulingUnit::epochs : return std::to_string(n) + "e"; - default: ABORT("corrupt enum value for scheduling unit"); - } - } -}; - class TrainingState { public: // Current epoch diff --git a/src/translator/translator.h b/src/translator/translator.h index 82d9343d..fe01065b 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -7,6 +7,9 @@ #include "data/shortlist.h" #include "data/text_input.h" +#include "common/scheduling_parameter.h" +#include "common/timer.h" + #include "3rd_party/threadpool.h" #include "translator/history.h" @@ -130,11 +133,42 @@ public: if(options_->get<bool>("quiet-translation")) collector->setPrintingStrategy(New<QuietPrinting>()); - bg.prepare(); + // mutex for syncing counter and timer updates + std::mutex syncCounts; + + // timer and counters for total elapsed time and statistics + std::unique_ptr<timer::Timer> totTimer(new timer::Timer()); + size_t totBatches = 0; + size_t totLines = 0; + size_t totSourceTokens = 0; + + // timer and counters for elapsed time and statistics between updates + std::unique_ptr<timer::Timer> curTimer(new timer::Timer()); + size_t curBatches = 0; + size_t curLines = 0; + size_t curSourceTokens = 0; + + // determine if we want to display timer statistics, by default off + auto statFreq = SchedulingParameter::parse(options_->get<std::string>("stat-freq", "0u")); + // abort early to avoid potentially costly batching and translation before error message + ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq); + + // Override display for progress heartbeat for MS-internal Philly compute cluster + // otherwise this job may be killed prematurely if no log for 4 hrs + if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster + if(statFreq.n == 0) { + statFreq.n = 10000; + statFreq.unit = SchedulingUnit::updates; + } + } bool doNbest = options_->get<bool>("n-best"); + + bg.prepare(); for(auto batch : bg) { - auto task = [=](size_t id) { + auto task = [=, &syncCounts, + &totBatches, &totLines, &totSourceTokens, &totTimer, + &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) { thread_local Ptr<ExpressionGraph> graph; thread_local std::vector<Ptr<Scorer>> scorers; @@ -156,20 +190,44 @@ public: doNbest); } - - // progress heartbeat for MS-internal Philly compute cluster - // otherwise this job may be killed prematurely if no log for 4 hrs - if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster - && id % 1000 == 0) // hard beat once every 1000 batches - { - auto progress = 0.f; //fake progress for now - fprintf(stderr, "PROGRESS: %.2f%%\n", progress); - fflush(stderr); + // if we asked for speed information display this + if(statFreq.n > 0) { + std::lock_guard<std::mutex> lock(syncCounts); + totBatches++; + totLines += batch->size(); + totSourceTokens += batch->front()->batchWords(); + + curBatches++; + curLines += batch->size(); + curSourceTokens += batch->front()->batchWords(); + + if(totBatches % statFreq.n == 0) { + double totTime = totTimer->elapsed(); + double curTime = curTimer->elapsed(); + + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime); + + // reset stats between updates + curBatches = curLines = curSourceTokens = 0; + curTimer.reset(new timer::Timer()); + } } }; threadPool.enqueue(task, batchId++); + } + // make sure threads are joined before other local variables get de-allocated + threadPool.join_all(); + + // display final speed numbers over total translation if intermediate displays were requested + if(statFreq.n > 0) { + double totTime = totTimer->elapsed(); + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime); } } }; |