1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
|
#include "common/cli_wrapper.h"
#include "common/cli_helper.h"
#include "common/logging.h"
#include "common/options.h"
#include "common/timer.h"
#include "common/utils.h"
#include "common/version.h"
namespace marian {
namespace cli {
// clang-format off
// Option names that may still be present in config files produced by older versions of Marian.
// updateConfig() silently skips them instead of reporting them as unknown options.
const std::unordered_set<std::string> DEPRECATED_OPTIONS = {
  "version", "special-vocab",
  // @TODO: uncomment once we actually deprecate them.
  // "after-batches",
  // "after-epochs"
};
// clang-format on
/*
static uint16_t guess_terminal_width(uint16_t max_width, uint16_t default_width) {
uint16_t cols = 0;
#ifdef TIOCGSIZE
struct ttysize ts;
ioctl(STDIN_FILENO, TIOCGSIZE, &ts);
if(ts.ts_cols != 0)
cols = ts.ts_cols;
#elif defined(TIOCGWINSZ)
struct winsize ts;
ioctl(STDIN_FILENO, TIOCGWINSZ, &ts);
if(ts.ws_col != 0)
cols = ts.ws_col;
#endif
// couldn't determine terminal width
if(cols == 0)
cols = default_width;
return max_width ? std::min(cols, max_width) : cols;
}
*/
// Help-message formatter.
//   columnWidth - width of the column holding option names (forwarded to CLI11's column_width)
//   screenWidth - total line width used when wrapping option descriptions; 0 disables wrapping
CLIFormatter::CLIFormatter(size_t columnWidth, size_t screenWidth)
    : CLI::Formatter(),
      screenWidth_(screenWidth) {
  this->column_width(columnWidth);
}
// Returns the description of `opt` for the help message.
//
// If a screen width is configured, leaves room after the option-name column, and the description
// does not fit on one line, the description is re-flowed word by word. Continuation lines are
// indented by the column width so that they line up under the description column. Words are
// split on whitespace, so runs of whitespace (including newlines) in the original description
// collapse to single spaces.
std::string CLIFormatter::make_option_desc(const CLI::Option *opt) const {
  auto desc = opt->get_description();
  const size_t indent = get_column_width();
  // TODO: restore guessing terminal width
  // Requiring screenWidth_ > indent also rules out screenWidth_ == 0 (wrapping disabled) and
  // prevents the unsigned subtraction below from underflowing.
  if(screenWidth_ > indent && desc.size() + indent > screenWidth_) {
    const size_t maxWidth = screenWidth_ - indent;
    std::istringstream descIn(desc);
    std::ostringstream descOut;
    size_t lineLen = 0;  // number of characters already written on the current output line
    std::string word;
    while(descIn >> word) {
      if(lineLen > 0) {
        // Decide between a separating space and a line break before writing anything, so that
        // wrapped lines carry no trailing space. A word that is too long on its own stays on the
        // current (empty) line instead of producing a blank leading line.
        if(lineLen + 1 + word.length() > maxWidth) {
          descOut << '\n' << std::string(indent, ' ');
          lineLen = 0;
        } else {
          descOut << ' ';
          lineLen += 1;
        }
      }
      descOut << word;
      lineLen += word.length();
    }
    desc = descOut.str();
  }
  return desc;
}
// Creates the CLI11 application that fills `config` with option values.
//   description - application description passed to CLI::App
//   header      - name of the default option group (also the initial current group)
//   footer      - optional text shown after the option list; ignored if empty
//   columnWidth - width of the option-name column in the help message
//   screenWidth - line width used to wrap option descriptions (0 disables wrapping)
// Also registers a --version flag in the default group; it is handled by parse().
CLIWrapper::CLIWrapper(YAML::Node &config,
                       const std::string &description,
                       const std::string &header,
                       const std::string &footer,
                       size_t columnWidth,
                       size_t screenWidth)
    : app_(std::make_shared<CLI::App>(description)),
      defaultGroup_(header),
      currentGroup_(header),
      config_(config) {
  // The footer is prefixed with a newline
  if(!footer.empty()) {
    app_->footer("\n" + footer);
  }

  // CLI11 adds --help automatically; put it into the default group
  auto *helpOption = app_->get_help_ptr();
  helpOption->group(defaultGroup_);

  // Custom error message and custom help formatting
  app_->failure_message(failureMessage);
  auto helpFormatter = std::make_shared<CLIFormatter>(columnWidth, screenWidth);
  app_->formatter(helpFormatter);

  optVersion_ = app_->add_flag("--version", "Print the version number and exit");
  optVersion_->group(defaultGroup_);
}
// Out-of-line defaulted destructor; it performs no cleanup of its own.
CLIWrapper::~CLIWrapper() = default;
// Makes `name` the current option group and returns the group that was current before the call.
// An empty `name` selects the default group.
std::string CLIWrapper::switchGroup(std::string name) {
  std::string previousGroup = currentGroup_;
  if(name.empty())
    currentGroup_ = defaultGroup_;
  else
    currentGroup_ = name;
  return previousGroup;
}
void CLIWrapper::parse(int argc, char **argv) {
try {
app_->parse(argc, argv);
} catch(const CLI::ParseError &e) {
exit(app_->exit(e));
}
// handle --version flag
if(optVersion_->count()) {
std::cerr << buildVersion() << std::endl;
exit(0);
}
}
// Expands alias options and then removes the alias options from the global config.
//
// Each entry of aliases_ associates an (option key, value) pair with a YAML config. If the global
// config holds that value for the key — as its scalar value, or as one element of a sequence —
// the associated config is merged via updateConfig() with the priority of the alias option.
// For sequence-valued alias options, any element that no alias declares for that key aborts.
// Finally, all alias keys are removed so that dumped/read config files do not contain them.
void CLIWrapper::parseAliases() {
  // Nothing to do if no aliases are defined
  if(aliases_.empty())
    return;

  // The set of values allowed for each alias option, used to reject unknown values below
  std::unordered_map<std::string, std::unordered_set<std::string>> allowedAliasValues;
  for(auto &&alias : aliases_)
    allowedAliasValues[alias.key].insert(alias.value);

  for(auto &&alias : aliases_) {
    // The alias option may come from the command line or a config file; skip it if absent
    if(!config_[alias.key])
      continue;

    // Two cases:
    // * the option is a sequence: extract it as a vector of strings and look for the value
    // * otherwise: compare values as strings
    bool expand = false;
    if(config_[alias.key].IsSequence()) {
      auto aliasOpts = config_[alias.key].as<std::vector<std::string>>();
      const auto &allowed = allowedAliasValues[alias.key];
      // Abort if an alias option has a value that has not been defined in common/aliases.cpp
      for(auto &&aliasOpt : aliasOpts) {
        if(allowed.count(aliasOpt) == 0) {
          std::vector<std::string> allowedOpts(allowed.begin(), allowed.end());
          // ABORT takes a format string: user-provided values are passed as arguments rather
          // than spliced into it, so values containing '{' or '}' cannot break formatting.
          ABORT("Unknown value '{}' for alias option --{}. Allowed values: {}",
                aliasOpt,
                alias.key,
                utils::join(allowedOpts, ", "));
        }
      }
      expand = std::find(aliasOpts.begin(), aliasOpts.end(), alias.value) != aliasOpts.end();
    } else {
      expand = config_[alias.key].as<std::string>() == alias.value;
    }

    if(expand) {
      // Merge the alias config into the global config; unknown options in it abort.
      // Expanded options get the same priority as the alias option itself.
      updateConfig(alias.config,
                   options_[alias.key].priority,
                   "Unknown option(s) in alias '" + alias.key + ": " + alias.value + "'");
    }
  }

  // Remove aliases from the global config to avoid redundancy when writing/reading config files
  for(const auto &alias : aliases_)
    config_.remove(alias.key);
}
// Merges the options of `config` into the global config with the given priority.
//
// An incoming option is skipped if it was given on the command line, if it is listed in
// DEPRECATED_OPTIONS, or if its current priority is not lower than `priority`. Otherwise its
// value replaces the current one when the YAML node types match. A scalar may also replace a
// sequence, in which case it is stored as a one-element sequence. Any other type mismatch aborts.
// Options not defined in the CLI are collected and reported together in one abort whose message
// starts with `errorMsg`.
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
  auto cmdOptions = getParsedOptionNames();
  // Keep track of unrecognized options from the provided config
  std::vector<std::string> unknownOpts;

  for(const auto &it : config) {
    auto key = it.first.as<std::string>();
    // Skip options specified via command-line to allow overwriting them
    if(cmdOptions.count(key))
      continue;
    // Skip options that might exist in config files generated by older versions of Marian
    if(DEPRECATED_OPTIONS.count(key))
      continue;

    auto option = options_.find(key);
    if(option == options_.end()) {
      unknownOpts.push_back(key);
      continue;
    }
    // Do not proceed if the priority of the incoming option is not greater than the existing one
    if(priority <= option->second.priority)
      continue;

    if(config_[key] && config_[key].Type() == it.second.Type()) {
      config_[key] = YAML::Clone(it.second);
    } else if(config_[key].IsSequence() && it.second.IsScalar()) {
      // Upcast a scalar to a single-element sequence, replacing the default values
      YAML::Node sequence;
      sequence.push_back(YAML::Clone(it.second));
      config_[key] = sequence;
    } else {
      // Other non-matching types cannot be converted, e.g. scalar <- list.
      // `key` comes from external input, so it is a format argument, not part of the format.
      ABORT("Cannot convert values for the option: {}", key);
    }
    option->second.priority = priority;
  }

  ABORT_IF(!unknownOpts.empty(), "{}: {}", errorMsg, utils::join(unknownOpts, ", "));
}
std::string CLIWrapper::dumpConfig(bool skipUnmodified /*= false*/) const {
YAML::Emitter out;
out << YAML::Comment("Marian configuration file generated at " + timer::currentDate()
+ " with version " + buildVersion());
out << YAML::BeginMap;
std::string comment;
// Iterate option names in the same order as they have been created
for(const auto &key : getOrderedOptionNames()) {
// Do not dump options that were removed from config_
if(!config_[key])
continue;
// Do not dump options that were not passed via the command line
if(skipUnmodified && options_.at(key).priority == cli::OptionPriority::DefaultValue)
continue;
// Put the group name as a comment before the first option in the group
auto group = options_.at(key).opt->get_group();
if(comment != group) {
if(!comment.empty())
out << YAML::Newline;
comment = group;
out << YAML::Comment(group);
}
out << YAML::Key;
out << key;
out << YAML::Value;
cli::OutputYaml(config_[key], out);
}
out << YAML::EndMap;
return out.c_str();
}
// Returns the names of the options whose CLI11 option object is non-empty after parsing.
// updateConfig() treats these as specified on the command line.
std::unordered_set<std::string> CLIWrapper::getParsedOptionNames() const {
  std::unordered_set<std::string> parsedNames;
  for(auto entry = options_.begin(); entry != options_.end(); ++entry) {
    if(entry->second.opt->empty())
      continue;
    parsedNames.emplace(entry->first);
  }
  return parsedNames;
}
std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
std::vector<std::string> keys;
// extract all option names
for(auto const &it : options_)
keys.push_back(it.first);
// sort option names by creation index
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) {
return options_.at(a).idx < options_.at(b).idx;
});
return keys;
}
// Builds the text shown when command-line parsing fails. It consists of "Error: " plus the error
// text and, if the app has a help option, a hint to run that option. Installed via
// CLI::App::failure_message.
std::string CLIWrapper::failureMessage(const CLI::App *app, const CLI::Error &e) {
  std::string message = "Error: ";
  message += e.what();
  message += "\n";
  const auto *helpOption = app->get_help_ptr();
  if(helpOption != nullptr) {
    message += "Run with " + helpOption->get_name() + " for more information.\n";
  }
  return message;
}
} // namespace cli
} // namespace marian
|