Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikolay Bogoychev <nheart@gmail.com>2021-11-22 22:19:58 +0300
committerGitHub <noreply@github.com>2021-11-22 22:19:58 +0300
commit1adf80b7c9d2b3fc688cf16114e5e9b01425f3a2 (patch)
tree8eb42e0207360f984c62ad6ecc480369a2935f1c
parent3d15cd3d2020abf561b7e0d7ffa87b15baf0dfb1 (diff)
Task alias validation during training mode (#886)
* Attempt to validate task alias * Validate allowed options for --task alias * Update comment in aliases.cpp * Show allowed values for alias Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
-rw-r--r--CHANGELOG.md1
-rw-r--r--src/common/aliases.cpp2
-rw-r--r--src/common/cli_wrapper.cpp17
-rw-r--r--src/common/config_parser.cpp3
4 files changed, 21 insertions, 2 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 169a1a5e..4c624954 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fixed loading binary models on architectures where `size_t` != `uint64_t`.
- Missing float template specialisation for elem::Plus
- Broken links to MNIST data sets
+- Enforce validation for the task alias in training mode.
### Changed
- Optimize LSH for speed by treating is as a shortlist generator. No option changes in decoder
diff --git a/src/common/aliases.cpp b/src/common/aliases.cpp
index 0be26a8c..36613327 100644
--- a/src/common/aliases.cpp
+++ b/src/common/aliases.cpp
@@ -19,6 +19,8 @@ namespace marian {
* As aliases are key-value pairs by default, values are compared as std::string.
* If the command line option corresponding to the alias is a vector, the alias
* will be triggered if the requested value exists in that vector at least once.
+ * By design if an option value that is not defined for that alias option below
+ * is used, the CLI parser will abort with 'unknown value for alias' error.
*
* @see CLIWrapper::alias()
*
diff --git a/src/common/cli_wrapper.cpp b/src/common/cli_wrapper.cpp
index 9a5a1a2c..211dd0b9 100644
--- a/src/common/cli_wrapper.cpp
+++ b/src/common/cli_wrapper.cpp
@@ -132,8 +132,14 @@ void CLIWrapper::parseAliases() {
if(aliases_.empty())
return;
+ // Find the set of values allowed for each alias option.
+ // Later we will check and abort if an alias option has an unknown value.
+ std::unordered_map<std::string, std::unordered_set<std::string>> allowedAliasValues;
+ for(auto &&alias : aliases_)
+ allowedAliasValues[alias.key].insert(alias.value);
+
// Iterate all known aliases, each alias has a key, value, and config
- for(const auto &alias : aliases_) {
+ for(auto &&alias : aliases_) {
// Check if the alias option exists in the config (it may come from command line or a config
// file)
if(config_[alias.key]) {
@@ -145,6 +151,15 @@ void CLIWrapper::parseAliases() {
bool expand = false;
if(config_[alias.key].IsSequence()) {
auto aliasOpts = config_[alias.key].as<std::vector<std::string>>();
+ // Abort if an alias option has an unknown value, i.e. value that has not been defined
+ // in common/aliases.cpp
+ for(auto &&aliasOpt : aliasOpts)
+ if(allowedAliasValues[alias.key].count(aliasOpt) == 0) {
+ std::vector<std::string> allowedOpts(allowedAliasValues[alias.key].begin(),
+ allowedAliasValues[alias.key].end());
+ ABORT("Unknown value '" + aliasOpt + "' for alias option --" + alias.key + ". "
+ "Allowed values: " + utils::join(allowedOpts, ", "));
+ }
expand = std::find(aliasOpts.begin(), aliasOpts.end(), alias.value) != aliasOpts.end();
} else {
expand = config_[alias.key].as<std::string>() == alias.value;
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 30d77e36..8da9520c 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -557,7 +557,8 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
addSuboptionsULR(cli);
cli.add<std::vector<std::string>>("--task",
- "Use predefined set of options. Possible values: transformer, transformer-big");
+ "Use predefined set of options. Possible values: transformer-base, transformer-big, "
+ "transformer-base-prenorm, transformer-big-prenorm");
cli.switchGroup(previous_group);
// clang-format on
}