Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-03-22 18:58:04 +0300
committerGitHub <noreply@github.com>2021-03-22 18:58:04 +0300
commit0394d2cdbe42b7ad7fc02854da8e063f2a23496f (patch)
treeaf669285a1906ba07e0397bd19e1ef33ae9b633a /src
parent096c48e51cd2e61bb275345d7cca99cbfd6bc5c7 (diff)
Display decoder speed statistics with --stat-freq N (#841)
Display decoder time statistics if requested
Diffstat (limited to 'src')
-rw-r--r--src/common/config_parser.cpp3
-rw-r--r--src/common/scheduling_parameter.h53
-rw-r--r--src/training/training_state.h44
-rw-r--r--src/translator/translator.h80
4 files changed, 126 insertions, 54 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 3baa13ea..602509c5 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -658,6 +658,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
->implicit_val("1");
cli.add<bool>("--word-scores",
"Print word-level scores. One score per subword unit, not normalized even if --normalize");
+ cli.add<std::string/*SchedulerPeriod*/>("--stat-freq",
+ "Display speed information every arg mini-batches. Disabled by default with 0, set to value larger than 0 to activate",
+ "0");
#ifdef USE_SENTENCEPIECE
cli.add<bool>("--no-spm-decode",
"Keep the output segmented into SentencePiece subwords");
diff --git a/src/common/scheduling_parameter.h b/src/common/scheduling_parameter.h
new file mode 100644
index 00000000..b6e39acf
--- /dev/null
+++ b/src/common/scheduling_parameter.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include "common/logging.h"
+#include "common/utils.h"
+
+#include <string>
+
+namespace marian {
+
+// support for scheduling parameters that can be expressed with a unit, such as --lr-decay-inv-sqrt
+enum class SchedulingUnit {
+ trgLabels, // "t": number of target labels seen so far
+ updates, // "u": number of updates so far (batches)
+ epochs // "e": number of epochs begun so far (very first epoch is 1)
+};
+
+struct SchedulingParameter {
+ size_t n{0}; // number of steps measured in 'unit'
+ SchedulingUnit unit{SchedulingUnit::updates}; // unit of value
+
+ // parses scheduling parameters of the form NU where N=unsigned int and U=unit
+ // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels),
+ // "100e" (100 epochs).
+ static SchedulingParameter parse(std::string param) {
+ SchedulingParameter res;
+ if(!param.empty() && param.back() >= 'a') {
+ switch(param.back()) {
+ case 't': res.unit = SchedulingUnit::trgLabels; break;
+ case 'u': res.unit = SchedulingUnit::updates; break;
+ case 'e': res.unit = SchedulingUnit::epochs; break;
+ default: ABORT("invalid unit '{}' in {}", param.back(), param);
+ }
+ param.pop_back();
+ }
+ double number = utils::parseNumber(param);
+ res.n = (size_t)number;
+ ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers"); // @TODO: do they?
+ return res;
+ }
+
+ operator bool() const { return n > 0; } // check whether it is specified
+
+ operator std::string() const { // convert back for storing in config
+ switch(unit) {
+ case SchedulingUnit::trgLabels: return std::to_string(n) + "t";
+ case SchedulingUnit::updates : return std::to_string(n) + "u";
+ case SchedulingUnit::epochs : return std::to_string(n) + "e";
+ default: ABORT("corrupt enum value for scheduling unit");
+ }
+ }
+};
+
+} \ No newline at end of file
diff --git a/src/training/training_state.h b/src/training/training_state.h
index 459e33cf..7d62f060 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -2,6 +2,7 @@
#include "common/definitions.h"
#include "common/filesystem.h"
+#include "common/scheduling_parameter.h"
#include "common/utils.h"
#include <fstream>
@@ -22,49 +23,6 @@ public:
virtual void actAfterLoaded(TrainingState&) {}
};
-// support for scheduling parameters that can be expressed with a unit, such as --lr-decay-inv-sqrt
-enum class SchedulingUnit {
- trgLabels, // "t": number of target labels seen so far
- updates, // "u": number of updates so far (batches)
- epochs // "e": number of epochs begun so far (very first epoch is 1)
-};
-
-struct SchedulingParameter {
- size_t n{0}; // number of steps measured in 'unit'
- SchedulingUnit unit{SchedulingUnit::updates}; // unit of value
-
- // parses scheduling parameters of the form NU where N=unsigned int and U=unit
- // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels),
- // "100e" (100 epochs).
- static SchedulingParameter parse(std::string param) {
- SchedulingParameter res;
- if(!param.empty() && param.back() >= 'a') {
- switch(param.back()) {
- case 't': res.unit = SchedulingUnit::trgLabels; break;
- case 'u': res.unit = SchedulingUnit::updates; break;
- case 'e': res.unit = SchedulingUnit::epochs; break;
- default: ABORT("invalid unit '{}' in {}", param.back(), param);
- }
- param.pop_back();
- }
- double number = utils::parseNumber(param);
- res.n = (size_t)number;
- ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers"); // @TODO: do they?
- return res;
- }
-
- operator bool() const { return n > 0; } // check whether it is specified
-
- operator std::string() const { // convert back for storing in config
- switch(unit) {
- case SchedulingUnit::trgLabels: return std::to_string(n) + "t";
- case SchedulingUnit::updates : return std::to_string(n) + "u";
- case SchedulingUnit::epochs : return std::to_string(n) + "e";
- default: ABORT("corrupt enum value for scheduling unit");
- }
- }
-};
-
class TrainingState {
public:
// Current epoch
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 82d9343d..fe01065b 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -7,6 +7,9 @@
#include "data/shortlist.h"
#include "data/text_input.h"
+#include "common/scheduling_parameter.h"
+#include "common/timer.h"
+
#include "3rd_party/threadpool.h"
#include "translator/history.h"
@@ -130,11 +133,42 @@ public:
if(options_->get<bool>("quiet-translation"))
collector->setPrintingStrategy(New<QuietPrinting>());
- bg.prepare();
+ // mutex for syncing counter and timer updates
+ std::mutex syncCounts;
+
+ // timer and counters for total elapsed time and statistics
+ std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
+ size_t totBatches = 0;
+ size_t totLines = 0;
+ size_t totSourceTokens = 0;
+
+ // timer and counters for elapsed time and statistics between updates
+ std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
+ size_t curBatches = 0;
+ size_t curLines = 0;
+ size_t curSourceTokens = 0;
+
+ // determine if we want to display timer statistics, by default off
+ auto statFreq = SchedulingParameter::parse(options_->get<std::string>("stat-freq", "0u"));
+ // abort early to avoid potentially costly batching and translation before error message
+ ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq);
+
+ // Override display for progress heartbeat for MS-internal Philly compute cluster
+ // otherwise this job may be killed prematurely if no log for 4 hrs
+ if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster
+ if(statFreq.n == 0) {
+ statFreq.n = 10000;
+ statFreq.unit = SchedulingUnit::updates;
+ }
+ }
bool doNbest = options_->get<bool>("n-best");
+
+ bg.prepare();
for(auto batch : bg) {
- auto task = [=](size_t id) {
+ auto task = [=, &syncCounts,
+ &totBatches, &totLines, &totSourceTokens, &totTimer,
+ &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@@ -156,20 +190,44 @@ public:
doNbest);
}
-
- // progress heartbeat for MS-internal Philly compute cluster
- // otherwise this job may be killed prematurely if no log for 4 hrs
- if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
- && id % 1000 == 0) // hard beat once every 1000 batches
- {
- auto progress = 0.f; //fake progress for now
- fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
- fflush(stderr);
+ // if we asked for speed information display this
+ if(statFreq.n > 0) {
+ std::lock_guard<std::mutex> lock(syncCounts);
+ totBatches++;
+ totLines += batch->size();
+ totSourceTokens += batch->front()->batchWords();
+
+ curBatches++;
+ curLines += batch->size();
+ curSourceTokens += batch->front()->batchWords();
+
+ if(totBatches % statFreq.n == 0) {
+ double totTime = totTimer->elapsed();
+ double curTime = curTimer->elapsed();
+
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
+
+ // reset stats between updates
+ curBatches = curLines = curSourceTokens = 0;
+ curTimer.reset(new timer::Timer());
+ }
}
};
threadPool.enqueue(task, batchId++);
+ }
+ // make sure threads are joined before other local variables get de-allocated
+ threadPool.join_all();
+
+ // display final speed numbers over total translation if intermediate displays were requested
+ if(statFreq.n > 0) {
+ double totTime = totTimer->elapsed();
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
}
}
};