Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2020-05-31 18:53:07 +0300
committer Taku Kudo <taku@google.com> 2020-05-31 18:53:07 +0300
commit d36b81fdf338e1ce7c3b08ff0bbf0a94cb5b1cf9 (patch)
tree 5c447782068e2738427c121f4dfe20a703e3c2a6
parent 706b35a37271170175baa3f11f823428f0c9041a (diff)
Use absl::flags
-rw-r--r-- src/CMakeLists.txt 9
-rw-r--r-- src/bpe_model_trainer_test.cc 18
-rw-r--r-- src/builder_test.cc 39
-rw-r--r-- src/char_model_trainer_test.cc 6
-rw-r--r-- src/common.h 12
-rw-r--r-- src/compile_charsmap_main.cc 10
-rw-r--r-- src/filesystem_test.cc 4
-rw-r--r-- src/flags.cc 239
-rw-r--r-- src/flags.h 88
-rw-r--r-- src/init.cc 32
-rw-r--r-- src/init.h 23
-rw-r--r-- src/init_test.cc (renamed from src/flags_test.cc) 72
-rw-r--r-- src/sentencepiece_processor_test.cc 9
-rw-r--r-- src/sentencepiece_trainer.cc 5
-rw-r--r-- src/sentencepiece_trainer_test.cc 37
-rw-r--r-- src/spm_decode_main.cc 59
-rw-r--r-- src/spm_encode_main.cc 102
-rw-r--r-- src/spm_export_vocab_main.cc 28
-rw-r--r-- src/spm_normalize_main.cc 63
-rw-r--r-- src/spm_train_main.cc 227
-rw-r--r-- src/test_main.cc 10
-rw-r--r-- src/testharness.cc 6
-rw-r--r-- src/testharness.h 17
-rw-r--r-- src/trainer_interface_test.cc 11
-rw-r--r-- src/unigram_model_trainer_test.cc 11
-rw-r--r-- src/util_test.cc 4
-rw-r--r-- src/word_model_trainer_test.cc 6
-rw-r--r-- third_party/absl/flags/flag.cc 220
-rw-r--r-- third_party/absl/flags/flag.h 64
29 files changed, 749 insertions, 682 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e0bfd5e..07316a1 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -78,7 +78,7 @@ set(SPM_SRCS
util.h
freelist.h
filesystem.h
- flags.h
+ init.h
sentencepiece_processor.h
word_model.h
model_factory.h
@@ -90,7 +90,7 @@ set(SPM_SRCS
char_model.cc
error.cc
filesystem.cc
- flags.cc
+ init.cc
model_factory.cc
model_interface.cc
normalizer.cc
@@ -98,7 +98,8 @@ set(SPM_SRCS
unigram_model.cc
util.cc
word_model.cc
- ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/absl/strings/string_view.cc)
+ ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/absl/strings/string_view.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/absl/flags/flag.cc)
set(SPM_TRAIN_SRCS
${SPM_PROTO_HDRS}
@@ -136,7 +137,7 @@ set(SPM_TEST_SRCS
char_model_test.cc
char_model_trainer_test.cc
filesystem_test.cc
- flags_test.cc
+ init_test.cc
model_factory_test.cc
model_interface_test.cc
normalizer_test.cc
diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc
index 027b1df..173eb9c 100644
--- a/src/bpe_model_trainer_test.cc
+++ b/src/bpe_model_trainer_test.cc
@@ -34,8 +34,10 @@ namespace {
std::string RunTrainer(
const std::vector<std::string> &input, int size,
const std::vector<std::string> &user_defined_symbols = {}) {
- const std::string input_file = util::JoinPath(FLAGS_test_tmpdir, "input");
- const std::string model_prefix = util::JoinPath(FLAGS_test_tmpdir, "model");
+ const std::string input_file =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
+ const std::string model_prefix =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
@@ -90,12 +92,14 @@ TEST(BPETrainerTest, BasicTest) {
static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";
TEST(BPETrainerTest, EndToEndTest) {
- const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestInputData);
+ const std::string input =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);
ASSERT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat(
- "--model_prefix=", util::JoinPath(FLAGS_test_tmpdir, "tmp_model"),
+ "--model_prefix=",
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
" --input=", input,
" --vocab_size=8000 --normalization_rule_name=identity"
" --model_type=bpe --control_symbols=<ctrl> "
@@ -103,9 +107,9 @@ TEST(BPETrainerTest, EndToEndTest) {
.ok());
SentencePieceProcessor sp;
- ASSERT_TRUE(
- sp.Load(std::string(util::JoinPath(FLAGS_test_tmpdir, "tmp_model.model")))
- .ok());
+ ASSERT_TRUE(sp.Load(std::string(util::JoinPath(
+ absl::GetFlag(FLAGS_test_tmpdir), "tmp_model.model")))
+ .ok());
EXPECT_EQ(8000, sp.GetPieceSize());
const int cid = sp.PieceToId("<ctrl>");
diff --git a/src/builder_test.cc b/src/builder_test.cc
index 5dea084..4acb7b3 100644
--- a/src/builder_test.cc
+++ b/src/builder_test.cc
@@ -141,9 +141,11 @@ static constexpr char kTestInputData[] = "nfkc.tsv";
TEST(BuilderTest, LoadCharsMapTest) {
Builder::CharsMap chars_map;
- ASSERT_TRUE(Builder::LoadCharsMap(
- util::JoinPath(FLAGS_test_srcdir, kTestInputData), &chars_map)
- .ok());
+ ASSERT_TRUE(
+ Builder::LoadCharsMap(
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData),
+ &chars_map)
+ .ok());
std::string precompiled, expected;
ASSERT_TRUE(Builder::CompileCharsMap(chars_map, &precompiled).ok());
@@ -154,14 +156,17 @@ TEST(BuilderTest, LoadCharsMapTest) {
Builder::DecompileCharsMap(precompiled, &decompiled_chars_map).ok());
EXPECT_EQ(chars_map, decompiled_chars_map);
- ASSERT_TRUE(Builder::SaveCharsMap(
- util::JoinPath(FLAGS_test_tmpdir, "output.tsv"), chars_map)
- .ok());
+ ASSERT_TRUE(
+ Builder::SaveCharsMap(
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
+ chars_map)
+ .ok());
Builder::CharsMap saved_chars_map;
ASSERT_TRUE(
- Builder::LoadCharsMap(util::JoinPath(FLAGS_test_tmpdir, "output.tsv"),
- &saved_chars_map)
+ Builder::LoadCharsMap(
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
+ &saved_chars_map)
.ok());
EXPECT_EQ(chars_map, saved_chars_map);
@@ -175,7 +180,7 @@ TEST(BuilderTest, LoadCharsMapTest) {
TEST(BuilderTest, LoadCharsMapWithEmptyeTest) {
{
auto output = filesystem::NewWritableFile(
- util::JoinPath(FLAGS_test_tmpdir, "test.tsv"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"));
output->WriteLine("0061\t0041");
output->WriteLine("0062");
output->WriteLine("0063\t\t#foo=>bar");
@@ -183,7 +188,8 @@ TEST(BuilderTest, LoadCharsMapWithEmptyeTest) {
Builder::CharsMap chars_map;
EXPECT_TRUE(Builder::LoadCharsMap(
- util::JoinPath(FLAGS_test_tmpdir, "test.tsv"), &chars_map)
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"),
+ &chars_map)
.ok());
EXPECT_EQ(3, chars_map.size());
@@ -191,14 +197,17 @@ TEST(BuilderTest, LoadCharsMapWithEmptyeTest) {
EXPECT_EQ(std::vector<char32>({}), chars_map[{0x0062}]);
EXPECT_EQ(std::vector<char32>({}), chars_map[{0x0063}]);
- EXPECT_TRUE(Builder::SaveCharsMap(
- util::JoinPath(FLAGS_test_tmpdir, "test_out.tsv"), chars_map)
- .ok());
+ EXPECT_TRUE(
+ Builder::SaveCharsMap(
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
+ chars_map)
+ .ok());
Builder::CharsMap new_chars_map;
EXPECT_TRUE(
- Builder::LoadCharsMap(util::JoinPath(FLAGS_test_tmpdir, "test_out.tsv"),
- &new_chars_map)
+ Builder::LoadCharsMap(
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
+ &new_chars_map)
.ok());
EXPECT_EQ(chars_map, new_chars_map);
}
diff --git a/src/char_model_trainer_test.cc b/src/char_model_trainer_test.cc
index 6767e29..8c2e4b7 100644
--- a/src/char_model_trainer_test.cc
+++ b/src/char_model_trainer_test.cc
@@ -31,8 +31,10 @@ namespace {
#define WS "\xE2\x96\x81"
std::string RunTrainer(const std::vector<std::string> &input, int size) {
- const std::string input_file = util::JoinPath(FLAGS_test_tmpdir, "input");
- const std::string model_prefix = util::JoinPath(FLAGS_test_tmpdir, "model");
+ const std::string input_file =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
+ const std::string model_prefix =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
diff --git a/src/common.h b/src/common.h
index a1e5d14..5d23e07 100644
--- a/src/common.h
+++ b/src/common.h
@@ -19,6 +19,7 @@
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
+
#include <iostream>
#include <memory>
#include <string>
@@ -26,7 +27,7 @@
#include <vector>
#include "config.h"
-#include "flags.h"
+#include "third_party/absl/flags/flag.h"
#if defined(_WIN32) && !defined(__CYGWIN__)
#define OS_WIN
@@ -90,11 +91,6 @@ std::string WideToUtf8(const std::wstring &input);
} // namespace win32
#endif
-namespace flags {
-int GetMinLogLevel();
-void SetMinLogLevel(int minloglevel);
-} // namespace flags
-
namespace error {
void Abort();
@@ -149,8 +145,10 @@ inline const char *BaseName(const char *path) {
} // namespace logging
} // namespace sentencepiece
+ABSL_DECLARE_FLAG(int32, minloglevel);
+
#define LOG(severity) \
- (sentencepiece::flags::GetMinLogLevel() > \
+ (absl::GetFlag(FLAGS_minloglevel) > \
::sentencepiece::logging::LOG_##severity) \
? 0 \
: ::sentencepiece::error::Die( \
diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc
index e8fc072..c5a5188 100644
--- a/src/compile_charsmap_main.cc
+++ b/src/compile_charsmap_main.cc
@@ -20,13 +20,15 @@
#include "builder.h"
#include "filesystem.h"
-#include "flags.h"
+#include "init.h"
#include "sentencepiece_processor.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/string_view.h"
using sentencepiece::normalizer::Builder;
-DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file");
+ABSL_FLAG(bool, output_precompiled_header, false,
+ "make normalization_rule.h file");
namespace sentencepiece {
namespace {
@@ -154,7 +156,7 @@ struct BinaryBlob {
} // namespace sentencepiece
int main(int argc, char **argv) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
const std::vector<std::pair<
std::string,
@@ -178,7 +180,7 @@ int main(int argc, char **argv) {
CHECK_OK(Builder::SaveCharsMap(p.first + ".tsv", normalized_map));
}
- if (FLAGS_output_precompiled_header) {
+ if (absl::GetFlag(FLAGS_output_precompiled_header)) {
constexpr char kPrecompiledHeaderFileName[] = "normalization_rule.h";
auto output =
sentencepiece::filesystem::NewWritableFile(kPrecompiledHeaderFileName);
diff --git a/src/filesystem_test.cc b/src/filesystem_test.cc
index 9b842fc..790e756 100644
--- a/src/filesystem_test.cc
+++ b/src/filesystem_test.cc
@@ -28,7 +28,7 @@ TEST(UtilTest, FilesystemTest) {
{
auto output = filesystem::NewWritableFile(
- util::JoinPath(FLAGS_test_tmpdir, "test_file"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
for (size_t i = 0; i < kData.size(); ++i) {
output->WriteLine(kData[i]);
}
@@ -36,7 +36,7 @@ TEST(UtilTest, FilesystemTest) {
{
auto input = filesystem::NewReadableFile(
- util::JoinPath(FLAGS_test_tmpdir, "test_file"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
std::string line;
for (size_t i = 0; i < kData.size(); ++i) {
EXPECT_TRUE(input->ReadLine(&line));
diff --git a/src/flags.cc b/src/flags.cc
deleted file mode 100644
index 0904d11..0000000
--- a/src/flags.cc
+++ /dev/null
@@ -1,239 +0,0 @@
-// Copyright 2016 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.!
-
-#include "flags.h"
-
-#include <algorithm>
-#include <cctype>
-#include <iostream>
-#include <map>
-#include <sstream>
-#include <string>
-#include <utility>
-
-#include "common.h"
-#include "config.h"
-#include "util.h"
-
-namespace sentencepiece {
-namespace flags {
-
-struct Flag {
- int type;
- void *storage;
- const void *default_storage;
- std::string help;
-};
-
-static int32 g_minloglevel = 0;
-
-int GetMinLogLevel() { return g_minloglevel; }
-void SetMinLogLevel(int minloglevel) { g_minloglevel = minloglevel; }
-
-namespace {
-using FlagMap = std::map<std::string, Flag *>;
-using RestArgs = std::vector<char *>;
-
-RestArgs *GetRestArgs() {
- static auto *rest_args = new RestArgs;
- return rest_args;
-}
-
-FlagMap *GetFlagMap() {
- static auto *flag_map = new FlagMap;
- return flag_map;
-}
-
-bool SetFlag(const std::string &name, const std::string &value) {
- auto it = GetFlagMap()->find(name);
- if (it == GetFlagMap()->end()) {
- return false;
- }
-
- std::string v = value;
- Flag *flag = it->second;
-
- // If empty value is set, we assume true or emtpy string is set
- // for boolean or string option. With other types, setting fails.
- if (value.empty()) {
- switch (flag->type) {
- case B:
- v = "true";
- break;
- case S:
- v = "";
- break;
- default:
- return false;
- }
- }
-
-#define DEFINE_ARG(FLAG_TYPE, CPP_TYPE) \
- case FLAG_TYPE: { \
- CPP_TYPE *s = reinterpret_cast<CPP_TYPE *>(flag->storage); \
- CHECK(string_util::lexical_cast<CPP_TYPE>(v, s)); \
- break; \
- }
-
- switch (flag->type) {
- DEFINE_ARG(I, int32);
- DEFINE_ARG(B, bool);
- DEFINE_ARG(I64, int64);
- DEFINE_ARG(U64, uint64);
- DEFINE_ARG(D, double);
- DEFINE_ARG(S, std::string);
- default:
- break;
- }
-
- return true;
-} // namespace
-
-bool CommandLineGetFlag(int argc, char **argv, std::string *key,
- std::string *value, int *used_args) {
- key->clear();
- value->clear();
-
- *used_args = 1;
- const char *start = argv[0];
- if (start[0] != '-') {
- return false;
- }
-
- ++start;
- if (start[0] == '-') ++start;
- const std::string arg = start;
- const size_t n = arg.find("=");
- if (n != std::string::npos) {
- *key = arg.substr(0, n);
- *value = arg.substr(n + 1, arg.size() - n);
- return true;
- }
-
- key->assign(arg);
- value->clear();
-
- if (argc == 1) {
- return true;
- }
- start = argv[1];
- if (start[0] == '-') {
- return true;
- }
-
- *used_args = 2;
- value->assign(start);
- return true;
-}
-} // namespace
-
-FlagRegister::FlagRegister(const char *name, void *storage,
- const void *default_storage, int shortype,
- const char *help)
- : flag_(new Flag) {
- flag_->type = shortype;
- flag_->storage = storage;
- flag_->default_storage = default_storage;
- flag_->help = help;
- GetFlagMap()->insert(std::make_pair(std::string(name), flag_.get()));
-}
-
-FlagRegister::~FlagRegister() {}
-
-std::string PrintHelp(const char *programname) {
- std::ostringstream os;
- os << PACKAGE_STRING << "\n\n";
- os << "Usage: " << programname << " [options] files\n\n";
-
- for (const auto &it : *GetFlagMap()) {
- os << " --" << it.first << " (" << it.second->help << ")";
- const Flag *flag = it.second;
- switch (flag->type) {
- case I:
- os << " type: int32 default: "
- << *(reinterpret_cast<const int *>(flag->default_storage)) << '\n';
- break;
- case B:
- os << " type: bool default: "
- << (*(reinterpret_cast<const bool *>(flag->default_storage))
- ? "true"
- : "false")
- << '\n';
- break;
- case I64:
- os << " type: int64 default: "
- << *(reinterpret_cast<const int64 *>(flag->default_storage)) << '\n';
- break;
- case U64:
- os << " type: uint64 default: "
- << *(reinterpret_cast<const uint64 *>(flag->default_storage))
- << '\n';
- break;
- case D:
- os << " type: double default: "
- << *(reinterpret_cast<const double *>(flag->default_storage))
- << '\n';
- break;
- case S:
- os << " type: string default: "
- << *(reinterpret_cast<const std::string *>(flag->default_storage))
- << '\n';
- break;
- default:
- break;
- }
- }
-
- os << "\n\n";
-
- return os.str();
-}
-
-void ParseCommandLineFlags(const char *usage, int *iargc, char ***iargv,
- bool remove_arg) {
- int used_argc = 0;
- std::string key, value;
-
- auto *rest_args = GetRestArgs();
- char **argv = *iargv;
- int argc = *iargc;
-
- rest_args->clear();
- rest_args->push_back(argv[0]);
-
- for (int i = 1; i < argc; i += used_argc) {
- if (!CommandLineGetFlag(argc - i, argv + i, &key, &value, &used_argc)) {
- rest_args->push_back(argv[i]);
- continue;
- }
- if (key == "help") {
- std::cout << PrintHelp(argv[0]);
- error::Exit(0);
- } else if (key == "version") {
- std::cout << PACKAGE_STRING << " " << VERSION << std::endl;
- error::Exit(0);
- } else if (key == "minloglevel") {
- flags::SetMinLogLevel(atoi(value.c_str()));
- } else if (!SetFlag(key, value)) {
- std::cerr << "Unknown/Invalid flag " << key << "\n\n"
- << PrintHelp(argv[0]);
- error::Exit(1);
- }
- }
-
- *iargv = rest_args->data();
- *iargc = static_cast<int>(rest_args->size());
-}
-} // namespace flags
-} // namespace sentencepiece
diff --git a/src/flags.h b/src/flags.h
deleted file mode 100644
index 59fac17..0000000
--- a/src/flags.h
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright 2016 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.!
-
-#ifndef FLAGS_H_
-#define FLAGS_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace sentencepiece {
-namespace flags {
-
-enum { I, B, I64, U64, D, S };
-
-struct Flag;
-
-class FlagRegister {
- public:
- FlagRegister(const char *name, void *storage, const void *default_storage,
- int shorttpe, const char *help);
- ~FlagRegister();
-
- private:
- std::unique_ptr<Flag> flag_;
-};
-
-std::string PrintHelp(const char *programname);
-
-void ParseCommandLineFlags(const char *usage, int *argc, char ***argv,
- bool remvoe_flags = true);
-} // namespace flags
-} // namespace sentencepiece
-
-#define DEFINE_VARIABLE(type, shorttype, name, value, help) \
- namespace sentencepiece_flags_fL##shorttype { \
- using namespace sentencepiece::flags; \
- type FLAGS_##name = value; \
- static const type FLAGS_DEFAULT_##name = value; \
- static const sentencepiece::flags::FlagRegister fL##name( \
- #name, reinterpret_cast<void *>(&FLAGS_##name), \
- reinterpret_cast<const void *>(&FLAGS_DEFAULT_##name), shorttype, \
- help); \
- } \
- using sentencepiece_flags_fL##shorttype::FLAGS_##name
-
-#define DECLARE_VARIABLE(type, shorttype, name) \
- namespace sentencepiece_flags_fL##shorttype { \
- extern type FLAGS_##name; \
- } \
- using sentencepiece_flags_fL##shorttype::FLAGS_##name
-
-#define DEFINE_int32(name, value, help) \
- DEFINE_VARIABLE(int32, I, name, value, help)
-#define DECLARE_int32(name) DECLARE_VARIABLE(int32, I, name)
-
-#define DEFINE_int64(name, value, help) \
- DEFINE_VARIABLE(int64, I64, name, value, help)
-#define DECLARE_int64(name) DECLARE_VARIABLE(int64, I64, name)
-
-#define DEFINE_uint64(name, value, help) \
- DEFINE_VARIABLE(uint64, U64, name, value, help)
-#define DECLARE_uint64(name) DECLARE_VARIABLE(uint64, U64, name)
-
-#define DEFINE_double(name, value, help) \
- DEFINE_VARIABLE(double, D, name, value, help)
-#define DECLARE_double(name) DECLARE_VARIABLE(double, D, name)
-
-#define DEFINE_bool(name, value, help) \
- DEFINE_VARIABLE(bool, B, name, value, help)
-#define DECLARE_bool(name) DECLARE_VARIABLE(bool, B, name)
-
-#define DEFINE_string(name, value, help) \
- DEFINE_VARIABLE(std::string, S, name, value, help)
-#define DECLARE_string(name) DECLARE_VARIABLE(std::string, S, name)
-
-#endif // FLAGS_H_
diff --git a/src/init.cc b/src/init.cc
new file mode 100644
index 0000000..f1800c5
--- /dev/null
+++ b/src/init.cc
@@ -0,0 +1,32 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.!
+
+#include "init.h"
+
+#include "third_party/absl/flags/flag.h"
+
+namespace sentencepiece {
+
+// Parses command-line flags through absl::ParseCommandLine. When |remove_arg|
+// is true, *argc and *argv are rewritten to describe only the arguments that
+// ParseCommandLine handed back; otherwise this function does not touch them.
+// |usage| is not read by this function.
+void ParseCommandLineFlags(const char *usage, int *argc, char ***argv,
+                           bool remove_arg) {
+  const auto rest = absl::ParseCommandLine(*argc, *argv);
+  if (!remove_arg) return;
+
+  // Reuse the tail of the caller's argv array as storage for the leftovers.
+  char **tail = *argv + *argc - rest.size();
+  std::copy(rest.begin(), rest.end(), tail);
+  *argv = tail;
+  *argc = static_cast<int>(rest.size());
+}
+} // namespace sentencepiece
diff --git a/src/init.h b/src/init.h
new file mode 100644
index 0000000..a569c22
--- /dev/null
+++ b/src/init.h
@@ -0,0 +1,23 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.!
+
+#ifndef INIT_H_
+#define INIT_H_
+
+namespace sentencepiece {
+// Parses command-line flags. When |remove_arg| is true (the default), *argc
+// and *argv are updated so that only the arguments left over after flag
+// parsing remain visible to the caller.
+void ParseCommandLineFlags(const char *usage, int *argc, char ***argv,
+                           bool remove_arg = true);
+}  // namespace sentencepiece
+
+#endif // INIT_H_
diff --git a/src/flags_test.cc b/src/init_test.cc
index 51f0b21..da659bf 100644
--- a/src/flags_test.cc
+++ b/src/init_test.cc
@@ -12,39 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "flags.h"
+#include "init.h"
#include "common.h"
#include "testharness.h"
-DEFINE_int32(int32_f, 10, "int32_flags");
-DEFINE_bool(bool_f, false, "bool_flags");
-DEFINE_int64(int64_f, 20, "int64_flags");
-DEFINE_uint64(uint64_f, 30, "uint64_flags");
-DEFINE_double(double_f, 40.0, "double_flags");
-DEFINE_string(string_f, "str", "string_flags");
+ABSL_FLAG(int32, int32_f, 10, "int32_flags");
+ABSL_FLAG(bool, bool_f, false, "bool_flags");
+ABSL_FLAG(int64, int64_f, 20, "int64_flags");
+ABSL_FLAG(uint64, uint64_f, 30, "uint64_flags");
+ABSL_FLAG(double, double_f, 40.0, "double_flags");
+ABSL_FLAG(std::string, string_f, "str", "string_flags");
-namespace sentencepiece {
-namespace flags {
+using sentencepiece::ParseCommandLineFlags;
+namespace absl {
TEST(FlagsTest, DefaultValueTest) {
- EXPECT_EQ(10, FLAGS_int32_f);
- EXPECT_EQ(false, FLAGS_bool_f);
- EXPECT_EQ(20, FLAGS_int64_f);
- EXPECT_EQ(30, FLAGS_uint64_f);
- EXPECT_EQ(40.0, FLAGS_double_f);
- EXPECT_EQ("str", FLAGS_string_f);
-}
-
-TEST(FlagsTest, PrintHelpTest) {
- const std::string help = PrintHelp("foobar");
- EXPECT_NE(std::string::npos, help.find("foobar"));
- EXPECT_NE(std::string::npos, help.find("int32_flags"));
- EXPECT_NE(std::string::npos, help.find("bool_flags"));
- EXPECT_NE(std::string::npos, help.find("int64_flags"));
- EXPECT_NE(std::string::npos, help.find("uint64_flags"));
- EXPECT_NE(std::string::npos, help.find("double_flags"));
- EXPECT_NE(std::string::npos, help.find("string_flags"));
+ EXPECT_EQ(10, absl::GetFlag(FLAGS_int32_f));
+ EXPECT_EQ(false, absl::GetFlag(FLAGS_bool_f));
+ EXPECT_EQ(20, absl::GetFlag(FLAGS_int64_f));
+ EXPECT_EQ(30, absl::GetFlag(FLAGS_uint64_f));
+ EXPECT_EQ(40.0, absl::GetFlag(FLAGS_double_f));
+ EXPECT_EQ("str", absl::GetFlag(FLAGS_string_f));
}
TEST(FlagsTest, ParseCommandLineFlagsTest) {
@@ -56,12 +45,12 @@ TEST(FlagsTest, ParseCommandLineFlagsTest) {
char **argv = const_cast<char **>(kFlags);
ParseCommandLineFlags(kFlags[0], &argc, &argv);
- EXPECT_EQ(100, FLAGS_int32_f);
- EXPECT_EQ(true, FLAGS_bool_f);
- EXPECT_EQ(200, FLAGS_int64_f);
- EXPECT_EQ(300, FLAGS_uint64_f);
- EXPECT_EQ(400.0, FLAGS_double_f);
- EXPECT_EQ("foo", FLAGS_string_f);
+ EXPECT_EQ(100, absl::GetFlag(FLAGS_int32_f));
+ EXPECT_EQ(true, absl::GetFlag(FLAGS_bool_f));
+ EXPECT_EQ(200, absl::GetFlag(FLAGS_int64_f));
+ EXPECT_EQ(300, absl::GetFlag(FLAGS_uint64_f));
+ EXPECT_EQ(400.0, absl::GetFlag(FLAGS_double_f));
+ EXPECT_EQ("foo", absl::GetFlag(FLAGS_string_f));
EXPECT_EQ(4, argc);
EXPECT_EQ("program", std::string(argv[0]));
EXPECT_EQ("other1", std::string(argv[1]));
@@ -77,10 +66,10 @@ TEST(FlagsTest, ParseCommandLineFlagsTest2) {
char **argv = const_cast<char **>(kFlags);
ParseCommandLineFlags(kFlags[0], &argc, &argv);
- EXPECT_EQ(500, FLAGS_int32_f);
- EXPECT_EQ(600, FLAGS_int64_f);
- EXPECT_EQ(700, FLAGS_uint64_f);
- EXPECT_FALSE(FLAGS_bool_f);
+ EXPECT_EQ(500, absl::GetFlag(FLAGS_int32_f));
+ EXPECT_EQ(600, absl::GetFlag(FLAGS_int64_f));
+ EXPECT_EQ(700, absl::GetFlag(FLAGS_uint64_f));
+ EXPECT_FALSE(absl::GetFlag(FLAGS_bool_f));
EXPECT_EQ(1, argc);
}
@@ -90,8 +79,8 @@ TEST(FlagsTest, ParseCommandLineFlagsTest3) {
int argc = arraysize(kFlags);
char **argv = const_cast<char **>(kFlags);
ParseCommandLineFlags(kFlags[0], &argc, &argv);
- EXPECT_TRUE(FLAGS_bool_f);
- EXPECT_EQ(800, FLAGS_int32_f);
+ EXPECT_TRUE(absl::GetFlag(FLAGS_bool_f));
+ EXPECT_EQ(800, absl::GetFlag(FLAGS_int32_f));
EXPECT_EQ(1, argc);
}
@@ -129,7 +118,7 @@ TEST(FlagsTest, ParseCommandLineFlagsEmptyStringArgs) {
char **argv = const_cast<char **>(kFlags);
ParseCommandLineFlags(kFlags[0], &argc, &argv);
EXPECT_EQ(1, argc);
- EXPECT_EQ("", FLAGS_string_f);
+ EXPECT_EQ("", absl::GetFlag(FLAGS_string_f));
}
TEST(FlagsTest, ParseCommandLineFlagsEmptyBoolArgs) {
@@ -138,7 +127,7 @@ TEST(FlagsTest, ParseCommandLineFlagsEmptyBoolArgs) {
char **argv = const_cast<char **>(kFlags);
ParseCommandLineFlags(kFlags[0], &argc, &argv);
EXPECT_EQ(1, argc);
- EXPECT_TRUE(FLAGS_bool_f);
+ EXPECT_TRUE(absl::GetFlag(FLAGS_bool_f));
}
TEST(FlagsTest, ParseCommandLineFlagsEmptyIntArgs) {
@@ -147,5 +136,4 @@ TEST(FlagsTest, ParseCommandLineFlagsEmptyIntArgs) {
char **argv = const_cast<char **>(kFlags);
EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), );
}
-} // namespace flags
-} // namespace sentencepiece
+} // namespace absl
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index bceba2c..cb669e7 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -866,12 +866,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
{
auto output = filesystem::NewWritableFile(
- util::JoinPath(FLAGS_test_tmpdir, "model"), true);
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model"), true);
output->Write(model_proto.SerializeAsString());
}
SentencePieceProcessor sp;
- EXPECT_TRUE(sp.Load(util::JoinPath(FLAGS_test_tmpdir, "model")).ok());
+ EXPECT_TRUE(
+ sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model")).ok());
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
@@ -1343,10 +1344,10 @@ TEST(SentencePieceProcessorTest, VocabularyTest) {
auto GetInlineFilename = [](const std::string content) {
{
auto out = filesystem::NewWritableFile(
- util::JoinPath(FLAGS_test_tmpdir, "vocab.txt"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt"));
out->Write(content);
}
- return util::JoinPath(FLAGS_test_tmpdir, "vocab.txt");
+ return util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt");
};
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 6f40865..f2b5050 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -22,6 +22,7 @@
#include "normalizer.h"
#include "sentencepiece_trainer.h"
#include "spec_parser.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_split.h"
@@ -30,6 +31,8 @@
#include "trainer_factory.h"
#include "util.h"
+// Same type as the declaration in common.h, so both declarations always agree.
+ABSL_DECLARE_FLAG(int32, minloglevel);
+
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nmt_nfkc";
@@ -148,7 +151,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
} else if (key == "minloglevel") {
int v = 0;
CHECK_OR_RETURN(absl::SimpleAtoi(value, &v));
- flags::SetMinLogLevel(v);
+ absl::SetFlag(&FLAGS_minloglevel, v);
continue;
}
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index b78b1d2..9c5614f 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -49,8 +49,10 @@ void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer,
}
TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
- const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestData);
- const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
+ const std::string input =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
+ const std::string model =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
ASSERT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--input=", input, " --model_prefix=", model,
@@ -114,8 +116,10 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
size_t idx_ = 0;
};
- const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestData);
- const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
+ const std::string input =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
+ const std::string model =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
std::vector<std::string> sentences;
{
@@ -135,9 +139,12 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
}
TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
- std::string input = util::JoinPath(FLAGS_test_srcdir, kTestData);
- std::string rule = util::JoinPath(FLAGS_test_srcdir, kNfkcTestData);
- const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
+ std::string input =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
+ std::string rule =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kNfkcTestData);
+ const std::string model =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
EXPECT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--input=", input, " --model_prefix=", model,
@@ -148,12 +155,14 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
}
TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) {
- const std::string input_file = util::JoinPath(FLAGS_test_srcdir, kTestDataJa);
- const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
+ const std::string input_file =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestDataJa);
+ const std::string model =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
const std::string norm_rule_tsv =
- util::JoinPath(FLAGS_test_srcdir, kIdsNormTsv);
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kIdsNormTsv);
const std::string denorm_rule_tsv =
- util::JoinPath(FLAGS_test_srcdir, kIdsDenormTsv);
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kIdsDenormTsv);
EXPECT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat("--input=", input_file, " --model_prefix=", model,
@@ -175,8 +184,10 @@ TEST(SentencePieceTrainerTest, TrainErrorTest) {
TEST(SentencePieceTrainerTest, TrainTest) {
TrainerSpec trainer_spec;
- trainer_spec.add_input(util::JoinPath(FLAGS_test_srcdir, kTestData));
- trainer_spec.set_model_prefix(util::JoinPath(FLAGS_test_tmpdir, "m"));
+ trainer_spec.add_input(
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData));
+ trainer_spec.set_model_prefix(
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m"));
trainer_spec.set_vocab_size(1000);
NormalizerSpec normalizer_spec;
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec, normalizer_spec).ok());
diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc
index a65e615..7284eb8 100644
--- a/src/spm_decode_main.cc
+++ b/src/spm_decode_main.cc
@@ -19,44 +19,45 @@
#include "builtin_pb/sentencepiece.pb.h"
#include "common.h"
#include "filesystem.h"
-#include "flags.h"
+#include "init.h"
#include "sentencepiece_processor.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/str_split.h"
#include "util.h"
-DEFINE_string(model, "", "model file name");
-DEFINE_string(input, "", "input filename");
-DEFINE_string(output, "", "output filename");
-DEFINE_string(input_format, "piece", "choose from piece or id");
-DEFINE_string(output_format, "string", "choose from string or proto");
-DEFINE_string(extra_options, "",
- "':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
+ABSL_FLAG(std::string, model, "", "model file name");
+ABSL_FLAG(std::string, input, "", "input filename");
+ABSL_FLAG(std::string, output, "", "output filename");
+ABSL_FLAG(std::string, input_format, "piece", "choose from piece or id");
+ABSL_FLAG(std::string, output_format, "string", "choose from string or proto");
+ABSL_FLAG(std::string, extra_options, "",
+ "':' separated decoder extra options, e.g., \"reverse:bos:eos\"");
int main(int argc, char *argv[]) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
std::vector<std::string> rest_args;
- if (FLAGS_input.empty()) {
+ if (absl::GetFlag(FLAGS_input).empty()) {
for (int i = 1; i < argc; ++i) {
rest_args.push_back(std::string(argv[i]));
}
} else {
- rest_args.push_back(FLAGS_input);
+ rest_args.push_back(absl::GetFlag(FLAGS_input));
}
- CHECK(!FLAGS_model.empty());
+ if (rest_args.empty())
+ rest_args.push_back(""); // empty means that reading from stdin.
+
+ CHECK(!absl::GetFlag(FLAGS_model).empty());
sentencepiece::SentencePieceProcessor sp;
- CHECK_OK(sp.Load(FLAGS_model));
- CHECK_OK(sp.SetDecodeExtraOptions(FLAGS_extra_options));
+ CHECK_OK(sp.Load(absl::GetFlag(FLAGS_model)));
+ CHECK_OK(sp.SetDecodeExtraOptions(absl::GetFlag(FLAGS_extra_options)));
- auto output = sentencepiece::filesystem::NewWritableFile(FLAGS_output);
+ auto output =
+ sentencepiece::filesystem::NewWritableFile(absl::GetFlag(FLAGS_output));
CHECK_OK(output->status());
- if (rest_args.empty()) {
- rest_args.push_back(""); // empty means that reading from stdin.
- }
-
std::string detok, line;
sentencepiece::SentencePieceText spt;
std::function<void(const std::vector<std::string> &pieces)> process;
@@ -69,34 +70,36 @@ int main(int argc, char *argv[]) {
return ids;
};
- if (FLAGS_input_format == "piece") {
- if (FLAGS_output_format == "string") {
+ if (absl::GetFlag(FLAGS_input_format) == "piece") {
+ if (absl::GetFlag(FLAGS_output_format) == "string") {
process = [&](const std::vector<std::string> &pieces) {
CHECK_OK(sp.Decode(pieces, &detok));
output->WriteLine(detok);
};
- } else if (FLAGS_output_format == "proto") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "proto") {
process = [&](const std::vector<std::string> &pieces) {
CHECK_OK(sp.Decode(pieces, &spt));
};
} else {
- LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
+ LOG(FATAL) << "Unknown output format: "
+ << absl::GetFlag(FLAGS_output_format);
}
- } else if (FLAGS_input_format == "id") {
- if (FLAGS_output_format == "string") {
+ } else if (absl::GetFlag(FLAGS_input_format) == "id") {
+ if (absl::GetFlag(FLAGS_output_format) == "string") {
process = [&](const std::vector<std::string> &pieces) {
CHECK_OK(sp.Decode(ToIds(pieces), &detok));
output->WriteLine(detok);
};
- } else if (FLAGS_output_format == "proto") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "proto") {
process = [&](const std::vector<std::string> &pieces) {
CHECK_OK(sp.Decode(ToIds(pieces), &spt));
};
} else {
- LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
+ LOG(FATAL) << "Unknown output format: "
+ << absl::GetFlag(FLAGS_output_format);
}
} else {
- LOG(FATAL) << "Unknown input format: " << FLAGS_input_format;
+ LOG(FATAL) << "Unknown input format: " << absl::GetFlag(FLAGS_input_format);
}
for (const auto &filename : rest_args) {
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc
index 85480c6..572cba5 100644
--- a/src/spm_encode_main.cc
+++ b/src/spm_encode_main.cc
@@ -20,62 +20,64 @@
#include "builtin_pb/sentencepiece.pb.h"
#include "common.h"
#include "filesystem.h"
-#include "flags.h"
+#include "init.h"
#include "sentencepiece_processor.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
#include "trainer_interface.h"
-DEFINE_string(model, "", "model file name");
-DEFINE_string(
- output_format, "piece",
+ABSL_FLAG(std::string, model, "", "model file name");
+ABSL_FLAG(
+ std::string, output_format, "piece",
"choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto");
-DEFINE_string(input, "", "input filename");
-DEFINE_string(output, "", "output filename");
-DEFINE_string(extra_options, "",
- "':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
-DEFINE_int32(nbest_size, 10, "NBest size");
-DEFINE_double(alpha, 0.5, "Smoothing parameter for sampling mode.");
+ABSL_FLAG(std::string, input, "", "input filename");
+ABSL_FLAG(std::string, output, "", "output filename");
+ABSL_FLAG(std::string, extra_options, "",
+ "':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
+ABSL_FLAG(int32, nbest_size, 10, "NBest size");
+ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
// Piece restriction with vocabulary file.
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
-DEFINE_string(vocabulary, "",
- "Restrict the vocabulary. The encoder only emits the "
- "tokens in \"vocabulary\" file");
-DEFINE_int32(vocabulary_threshold, 0,
- "Words with frequency < threshold will be treated as OOV");
-DEFINE_bool(generate_vocabulary, false,
- "Generates vocabulary file instead of segmentation");
+ABSL_FLAG(std::string, vocabulary, "",
+ "Restrict the vocabulary. The encoder only emits the "
+ "tokens in \"vocabulary\" file");
+ABSL_FLAG(int32, vocabulary_threshold, 0,
+ "Words with frequency < threshold will be treated as OOV");
+ABSL_FLAG(bool, generate_vocabulary, false,
+ "Generates vocabulary file instead of segmentation");
int main(int argc, char *argv[]) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
std::vector<std::string> rest_args;
- if (FLAGS_input.empty()) {
+ if (absl::GetFlag(FLAGS_input).empty()) {
for (int i = 1; i < argc; ++i) {
rest_args.push_back(std::string(argv[i]));
}
} else {
- rest_args.push_back(FLAGS_input);
+ rest_args.push_back(absl::GetFlag(FLAGS_input));
}
- CHECK(!FLAGS_model.empty());
+ if (rest_args.empty())
+ rest_args.push_back(""); // empty means that reading from stdin.
+
+ CHECK(!absl::GetFlag(FLAGS_model).empty());
sentencepiece::SentencePieceProcessor sp;
- CHECK_OK(sp.Load(FLAGS_model));
- CHECK_OK(sp.SetEncodeExtraOptions(FLAGS_extra_options));
+ CHECK_OK(sp.Load(absl::GetFlag(FLAGS_model)));
+ CHECK_OK(sp.SetEncodeExtraOptions(absl::GetFlag(FLAGS_extra_options)));
- if (!FLAGS_vocabulary.empty()) {
- CHECK_OK(sp.LoadVocabulary(FLAGS_vocabulary, FLAGS_vocabulary_threshold));
+ if (!absl::GetFlag(FLAGS_vocabulary).empty()) {
+ CHECK_OK(sp.LoadVocabulary(absl::GetFlag(FLAGS_vocabulary),
+ absl::GetFlag(FLAGS_vocabulary_threshold)));
}
- auto output = sentencepiece::filesystem::NewWritableFile(FLAGS_output);
+ auto output =
+ sentencepiece::filesystem::NewWritableFile(absl::GetFlag(FLAGS_output));
CHECK_OK(output->status());
- if (rest_args.empty()) {
- rest_args.push_back(""); // empty means that reading from stdin.
- }
-
std::string line;
std::vector<std::string> sps;
std::vector<int> ids;
@@ -86,7 +88,10 @@ int main(int argc, char *argv[]) {
sentencepiece::NBestSentencePieceText nbest_spt;
std::function<void(const std::string &line)> process;
- if (FLAGS_generate_vocabulary) {
+ const int nbest_size = absl::GetFlag(FLAGS_nbest_size);
+ const float alpha = absl::GetFlag(FLAGS_alpha);
+
+ if (absl::GetFlag(FLAGS_generate_vocabulary)) {
process = [&](const std::string &line) {
CHECK_OK(sp.Encode(line, &spt));
for (const auto &piece : spt.pieces()) {
@@ -94,52 +99,53 @@ int main(int argc, char *argv[]) {
vocab[piece.piece()]++;
}
};
- } else if (FLAGS_output_format == "piece") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "piece") {
process = [&](const std::string &line) {
CHECK_OK(sp.Encode(line, &sps));
output->WriteLine(absl::StrJoin(sps, " "));
};
- } else if (FLAGS_output_format == "id") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "id") {
process = [&](const std::string &line) {
CHECK_OK(sp.Encode(line, &ids));
output->WriteLine(absl::StrJoin(ids, " "));
};
- } else if (FLAGS_output_format == "proto") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "proto") {
process = [&](const std::string &line) { CHECK_OK(sp.Encode(line, &spt)); };
- } else if (FLAGS_output_format == "sample_piece") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "sample_piece") {
process = [&](const std::string &line) {
- CHECK_OK(sp.SampleEncode(line, FLAGS_nbest_size, FLAGS_alpha, &sps));
+ CHECK_OK(sp.SampleEncode(line, nbest_size, alpha, &sps));
output->WriteLine(absl::StrJoin(sps, " "));
};
- } else if (FLAGS_output_format == "sample_id") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "sample_id") {
process = [&](const std::string &line) {
- CHECK_OK(sp.SampleEncode(line, FLAGS_nbest_size, FLAGS_alpha, &ids));
+ CHECK_OK(sp.SampleEncode(line, nbest_size, alpha, &ids));
output->WriteLine(absl::StrJoin(ids, " "));
};
- } else if (FLAGS_output_format == "sample_proto") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "sample_proto") {
process = [&](const std::string &line) {
- CHECK_OK(sp.SampleEncode(line, FLAGS_nbest_size, FLAGS_alpha, &spt));
+ CHECK_OK(sp.SampleEncode(line, nbest_size, alpha, &spt));
};
- } else if (FLAGS_output_format == "nbest_piece") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "nbest_piece") {
process = [&](const std::string &line) {
- CHECK_OK(sp.NBestEncode(line, FLAGS_nbest_size, &nbest_sps));
+ CHECK_OK(sp.NBestEncode(line, nbest_size, &nbest_sps));
for (const auto &result : nbest_sps) {
output->WriteLine(absl::StrJoin(result, " "));
}
};
- } else if (FLAGS_output_format == "nbest_id") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "nbest_id") {
process = [&](const std::string &line) {
- CHECK_OK(sp.NBestEncode(line, FLAGS_nbest_size, &nbest_ids));
+ CHECK_OK(sp.NBestEncode(line, nbest_size, &nbest_ids));
for (const auto &result : nbest_ids) {
output->WriteLine(absl::StrJoin(result, " "));
}
};
- } else if (FLAGS_output_format == "nbest_proto") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "nbest_proto") {
process = [&](const std::string &line) {
- CHECK_OK(sp.NBestEncode(line, FLAGS_nbest_size, &nbest_spt));
+ CHECK_OK(sp.NBestEncode(line, nbest_size, &nbest_spt));
};
} else {
- LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
+ LOG(FATAL) << "Unknown output format: "
+ << absl::GetFlag(FLAGS_output_format);
}
for (const auto &filename : rest_args) {
@@ -150,7 +156,7 @@ int main(int argc, char *argv[]) {
}
}
- if (FLAGS_generate_vocabulary) {
+ if (absl::GetFlag(FLAGS_generate_vocabulary)) {
for (const auto &it : sentencepiece::Sorted(vocab)) {
output->WriteLine(it.first + "\t" +
sentencepiece::string_util::SimpleItoa(it.second));
diff --git a/src/spm_export_vocab_main.cc b/src/spm_export_vocab_main.cc
index 729faf2..9b98f01 100644
--- a/src/spm_export_vocab_main.cc
+++ b/src/spm_export_vocab_main.cc
@@ -18,37 +18,41 @@
#include "builtin_pb/sentencepiece_model.pb.h"
#include "common.h"
#include "filesystem.h"
-#include "flags.h"
+#include "init.h"
#include "sentencepiece_processor.h"
+#include "third_party/absl/flags/flag.h"
-DEFINE_string(output, "", "Output filename");
-DEFINE_string(model, "", "input model file name");
-DEFINE_string(output_format, "vocab",
- "output format. choose from vocab or syms. vocab outputs pieces "
- "and scores, syms outputs pieces and indices.");
+ABSL_FLAG(std::string, output, "", "Output filename");
+ABSL_FLAG(std::string, model, "", "input model file name");
+ABSL_FLAG(std::string, output_format, "vocab",
+ "output format. choose from vocab or syms. vocab outputs pieces "
+ "and scores, syms outputs pieces and indices.");
int main(int argc, char *argv[]) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+
sentencepiece::SentencePieceProcessor sp;
- CHECK_OK(sp.Load(FLAGS_model));
+ CHECK_OK(sp.Load(absl::GetFlag(FLAGS_model)));
- auto output = sentencepiece::filesystem::NewWritableFile(FLAGS_output);
+ auto output =
+ sentencepiece::filesystem::NewWritableFile(absl::GetFlag(FLAGS_output));
CHECK_OK(output->status());
- if (FLAGS_output_format == "vocab") {
+ if (absl::GetFlag(FLAGS_output_format) == "vocab") {
for (const auto &piece : sp.model_proto().pieces()) {
std::ostringstream os;
os << piece.piece() << "\t" << piece.score();
output->WriteLine(os.str());
}
- } else if (FLAGS_output_format == "syms") {
+ } else if (absl::GetFlag(FLAGS_output_format) == "syms") {
for (int i = 0; i < sp.model_proto().pieces_size(); i++) {
std::ostringstream os;
os << sp.model_proto().pieces(i).piece() << "\t" << i;
output->WriteLine(os.str());
}
} else {
- LOG(FATAL) << "Unsupported output format: " << FLAGS_output_format;
+ LOG(FATAL) << "Unsupported output format: "
+ << absl::GetFlag(FLAGS_output_format);
}
return 0;
diff --git a/src/spm_normalize_main.cc b/src/spm_normalize_main.cc
index 4b6f5bc..244b974 100644
--- a/src/spm_normalize_main.cc
+++ b/src/spm_normalize_main.cc
@@ -17,24 +17,26 @@
#include "builtin_pb/sentencepiece_model.pb.h"
#include "common.h"
#include "filesystem.h"
-#include "flags.h"
+#include "init.h"
#include "normalizer.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
+#include "third_party/absl/flags/flag.h"
-DEFINE_string(model, "", "Model file name");
-DEFINE_bool(use_internal_normalization, false,
- "Use NormalizerSpec \"as-is\" to run the normalizer "
- "for SentencePiece segmentation");
-DEFINE_string(normalization_rule_name, "",
- "Normalization rule name. "
- "Choose from nfkc or identity");
-DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
-DEFINE_bool(remove_extra_whitespaces, true, "Remove extra whitespaces");
-DEFINE_bool(decompile, false,
- "Decompile compiled charamap and output it as TSV.");
-DEFINE_string(input, "", "Input filename");
-DEFINE_string(output, "", "Output filename");
+ABSL_FLAG(std::string, model, "", "Model file name");
+ABSL_FLAG(bool, use_internal_normalization, false,
+ "Use NormalizerSpec \"as-is\" to run the normalizer "
+ "for SentencePiece segmentation");
+ABSL_FLAG(std::string, normalization_rule_name, "",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
+ABSL_FLAG(std::string, normalization_rule_tsv, "",
+ "Normalization rule TSV file. ");
+ABSL_FLAG(bool, remove_extra_whitespaces, true, "Remove extra whitespaces");
+ABSL_FLAG(bool, decompile, false,
+ "Decompile compiled charamap and output it as TSV.");
+ABSL_FLAG(std::string, input, "", "Input filename");
+ABSL_FLAG(std::string, output, "", "Output filename");
using sentencepiece::ModelProto;
using sentencepiece::NormalizerSpec;
@@ -44,29 +46,30 @@ using sentencepiece::normalizer::Builder;
using sentencepiece::normalizer::Normalizer;
int main(int argc, char *argv[]) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
-
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
std::vector<std::string> rest_args;
- if (FLAGS_input.empty()) {
+
+ if (absl::GetFlag(FLAGS_input).empty()) {
for (int i = 1; i < argc; ++i) {
rest_args.push_back(std::string(argv[i]));
}
} else {
- rest_args.push_back(FLAGS_input);
+ rest_args.push_back(absl::GetFlag(FLAGS_input));
}
NormalizerSpec spec;
- if (!FLAGS_model.empty()) {
+ if (!absl::GetFlag(FLAGS_model).empty()) {
ModelProto model_proto;
SentencePieceProcessor sp;
- CHECK_OK(sp.Load(FLAGS_model));
+ CHECK_OK(sp.Load(absl::GetFlag(FLAGS_model)));
spec = sp.model_proto().normalizer_spec();
- } else if (!FLAGS_normalization_rule_tsv.empty()) {
- spec.set_normalization_rule_tsv(FLAGS_normalization_rule_tsv);
+ } else if (!absl::GetFlag(FLAGS_normalization_rule_tsv).empty()) {
+ spec.set_normalization_rule_tsv(
+ absl::GetFlag(FLAGS_normalization_rule_tsv));
CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
- } else if (!FLAGS_normalization_rule_name.empty()) {
- spec.set_name(FLAGS_normalization_rule_name);
+ } else if (!absl::GetFlag(FLAGS_normalization_rule_name).empty()) {
+ spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name));
CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
} else {
LOG(FATAL) << "Sets --model, normalization_rule_tsv, or "
@@ -74,20 +77,22 @@ int main(int argc, char *argv[]) {
}
// Uses the normalizer spec encoded in the model_pb.
- if (!FLAGS_use_internal_normalization) {
+ if (!absl::GetFlag(FLAGS_use_internal_normalization)) {
spec.set_add_dummy_prefix(false); // do not add dummy prefix.
spec.set_escape_whitespaces(false); // do not output meta symbol.
- spec.set_remove_extra_whitespaces(FLAGS_remove_extra_whitespaces);
+ spec.set_remove_extra_whitespaces(
+ absl::GetFlag(FLAGS_remove_extra_whitespaces));
}
- if (FLAGS_decompile) {
+ if (absl::GetFlag(FLAGS_decompile)) {
Builder::CharsMap chars_map;
CHECK_OK(
Builder::DecompileCharsMap(spec.precompiled_charsmap(), &chars_map));
- CHECK_OK(Builder::SaveCharsMap(FLAGS_output, chars_map));
+ CHECK_OK(Builder::SaveCharsMap(absl::GetFlag(FLAGS_output), chars_map));
} else {
const Normalizer normalizer(spec);
- auto output = sentencepiece::filesystem::NewWritableFile(FLAGS_output);
+ auto output =
+ sentencepiece::filesystem::NewWritableFile(absl::GetFlag(FLAGS_output));
CHECK_OK(output->status());
if (rest_args.empty()) {
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index a2ec3a7..6d990e0 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -15,8 +15,9 @@
#include <map>
#include "builtin_pb/sentencepiece_model.pb.h"
-#include "flags.h"
+#include "init.h"
#include "sentencepiece_trainer.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/ascii.h"
#include "third_party/absl/strings/str_split.h"
#include "util.h"
@@ -29,121 +30,129 @@ static sentencepiece::TrainerSpec kDefaultTrainerSpec;
static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
} // namespace
-DEFINE_string(input, "", "comma separated list of input sentences");
-DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
- "Input format. Supported format is `text` or `tsv`.");
-DEFINE_string(model_prefix, "", "output model prefix");
-DEFINE_string(model_type, "unigram",
- "model algorithm: unigram, bpe, word or char");
-DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
-DEFINE_string(accept_language, "",
- "comma-separated list of languages this model can accept");
-DEFINE_int32(self_test_sample_size, kDefaultTrainerSpec.self_test_sample_size(),
- "the size of self test samples");
-DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
- "character coverage to determine the minimum symbols");
-DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
- "maximum size of sentences the trainer loads");
-DEFINE_bool(shuffle_input_sentence,
- kDefaultTrainerSpec.shuffle_input_sentence(),
- "Randomly sample input sentences in advance. Valid when "
- "--input_sentence_size > 0");
-DEFINE_int32(seed_sentencepiece_size,
- kDefaultTrainerSpec.seed_sentencepiece_size(),
- "the size of seed sentencepieces");
-DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
- "Keeps top shrinking_factor pieces with respect to the loss");
-DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
- "number of threads for training");
-DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
- "number of EM sub-iterations");
-DEFINE_int32(max_sentencepiece_length,
- kDefaultTrainerSpec.max_sentencepiece_length(),
- "maximum length of sentence piece");
-DEFINE_int32(max_sentence_length, kDefaultTrainerSpec.max_sentence_length(),
- "maximum length of sentence in byte");
-DEFINE_bool(split_by_unicode_script,
- kDefaultTrainerSpec.split_by_unicode_script(),
- "use Unicode script to split sentence pieces");
-DEFINE_bool(split_by_number, kDefaultTrainerSpec.split_by_number(),
- "split tokens by numbers (0-9)");
-DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
- "use a white space to split sentence pieces");
-DEFINE_bool(split_digits, kDefaultTrainerSpec.split_digits(),
- "split all digits (0-9) into separate pieces");
-DEFINE_bool(treat_whitespace_as_suffix,
- kDefaultTrainerSpec.treat_whitespace_as_suffix(),
- "treat whitespace marker as suffix instead of prefix.");
-DEFINE_string(control_symbols, "", "comma separated list of control symbols");
-DEFINE_string(user_defined_symbols, "",
- "comma separated list of user defined symbols");
-DEFINE_string(required_chars, "",
- "UTF8 characters in this flag are always used in the character "
- "set regardless of --character_coverage");
-DEFINE_bool(byte_fallback, kDefaultTrainerSpec.byte_fallback(),
- "decompose unknown pieces into UTF-8 byte pieces");
-DEFINE_bool(vocabulary_output_piece_score,
- kDefaultTrainerSpec.vocabulary_output_piece_score(),
- "Define score in vocab file");
-DEFINE_string(normalization_rule_name, "nmt_nfkc",
- "Normalization rule name. "
- "Choose from nfkc or identity");
-DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
-DEFINE_string(denormalization_rule_tsv, "", "Denormalization rule TSV file.");
-DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
- "Add dummy whitespace at the beginning of text");
-DEFINE_bool(remove_extra_whitespaces,
- kDefaultNormalizerSpec.remove_extra_whitespaces(),
- "Removes leading, trailing, and "
- "duplicate internal whitespace");
-DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
- "If set to false, --vocab_size is considered as a soft limit.");
-DEFINE_bool(use_all_vocab, kDefaultTrainerSpec.use_all_vocab(),
- "If set to true, use all tokens as vocab. "
- "Valid for word/char models.");
-DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
-DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
- "Override BOS (<s>) id. Set -1 to disable BOS.");
-DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
- "Override EOS (</s>) id. Set -1 to disable EOS.");
-DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
- "Override PAD (<pad>) id. Set -1 to disable PAD.");
-DEFINE_string(unk_piece, kDefaultTrainerSpec.unk_piece(),
- "Override UNK (<unk>) piece.");
-DEFINE_string(bos_piece, kDefaultTrainerSpec.bos_piece(),
- "Override BOS (<s>) piece.");
-DEFINE_string(eos_piece, kDefaultTrainerSpec.eos_piece(),
- "Override EOS (</s>) piece.");
-DEFINE_string(pad_piece, kDefaultTrainerSpec.pad_piece(),
- "Override PAD (<pad>) piece.");
-DEFINE_string(unk_surface, kDefaultTrainerSpec.unk_surface(),
- "Dummy surface string for <unk>. In decoding <unk> is decoded to "
- "`unk_surface`.");
-DEFINE_bool(train_extremely_large_corpus,
- kDefaultTrainerSpec.train_extremely_large_corpus(),
- "Increase bit depth for unigram tokenization.");
+ABSL_FLAG(std::string, input, "", "comma separated list of input sentences");
+ABSL_FLAG(std::string, input_format, kDefaultTrainerSpec.input_format(),
+ "Input format. Supported format is `text` or `tsv`.");
+ABSL_FLAG(std::string, model_prefix, "", "output model prefix");
+ABSL_FLAG(std::string, model_type, "unigram",
+ "model algorithm: unigram, bpe, word or char");
+ABSL_FLAG(int32, vocab_size, kDefaultTrainerSpec.vocab_size(),
+ "vocabulary size");
+ABSL_FLAG(std::string, accept_language, "",
+ "comma-separated list of languages this model can accept");
+ABSL_FLAG(int32, self_test_sample_size,
+ kDefaultTrainerSpec.self_test_sample_size(),
+ "the size of self test samples");
+ABSL_FLAG(double, character_coverage, kDefaultTrainerSpec.character_coverage(),
+ "character coverage to determine the minimum symbols");
+ABSL_FLAG(int32, input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
+ "maximum size of sentences the trainer loads");
+ABSL_FLAG(bool, shuffle_input_sentence,
+ kDefaultTrainerSpec.shuffle_input_sentence(),
+ "Randomly sample input sentences in advance. Valid when "
+ "--input_sentence_size > 0");
+ABSL_FLAG(int32, seed_sentencepiece_size,
+ kDefaultTrainerSpec.seed_sentencepiece_size(),
+ "the size of seed sentencepieces");
+ABSL_FLAG(double, shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
+ "Keeps top shrinking_factor pieces with respect to the loss");
+ABSL_FLAG(int32, num_threads, kDefaultTrainerSpec.num_threads(),
+ "number of threads for training");
+ABSL_FLAG(int32, num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
+ "number of EM sub-iterations");
+ABSL_FLAG(int32, max_sentencepiece_length,
+ kDefaultTrainerSpec.max_sentencepiece_length(),
+ "maximum length of sentence piece");
+ABSL_FLAG(int32, max_sentence_length, kDefaultTrainerSpec.max_sentence_length(),
+ "maximum length of sentence in byte");
+ABSL_FLAG(bool, split_by_unicode_script,
+ kDefaultTrainerSpec.split_by_unicode_script(),
+ "use Unicode script to split sentence pieces");
+ABSL_FLAG(bool, split_by_number, kDefaultTrainerSpec.split_by_number(),
+ "split tokens by numbers (0-9)");
+ABSL_FLAG(bool, split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
+ "use a white space to split sentence pieces");
+ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(),
+ "split all digits (0-9) into separate pieces");
+ABSL_FLAG(bool, treat_whitespace_as_suffix,
+ kDefaultTrainerSpec.treat_whitespace_as_suffix(),
+ "treat whitespace marker as suffix instead of prefix.");
+ABSL_FLAG(std::string, control_symbols, "",
+ "comma separated list of control symbols");
+ABSL_FLAG(std::string, user_defined_symbols, "",
+ "comma separated list of user defined symbols");
+ABSL_FLAG(std::string, required_chars, "",
+ "UTF8 characters in this flag are always used in the character "
+ "set regardless of --character_coverage");
+ABSL_FLAG(bool, byte_fallback, kDefaultTrainerSpec.byte_fallback(),
+ "decompose unknown pieces into UTF-8 byte pieces");
+ABSL_FLAG(bool, vocabulary_output_piece_score,
+ kDefaultTrainerSpec.vocabulary_output_piece_score(),
+ "Define score in vocab file");
+ABSL_FLAG(std::string, normalization_rule_name, "nmt_nfkc",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
+ABSL_FLAG(std::string, normalization_rule_tsv, "",
+ "Normalization rule TSV file. ");
+ABSL_FLAG(std::string, denormalization_rule_tsv, "",
+ "Denormalization rule TSV file.");
+ABSL_FLAG(bool, add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
+ "Add dummy whitespace at the beginning of text");
+ABSL_FLAG(bool, remove_extra_whitespaces,
+ kDefaultNormalizerSpec.remove_extra_whitespaces(),
+ "Removes leading, trailing, and "
+ "duplicate internal whitespace");
+ABSL_FLAG(bool, hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
+ "If set to false, --vocab_size is considered as a soft limit.");
+ABSL_FLAG(bool, use_all_vocab, kDefaultTrainerSpec.use_all_vocab(),
+ "If set to true, use all tokens as vocab. "
+ "Valid for word/char models.");
+ABSL_FLAG(int32, unk_id, kDefaultTrainerSpec.unk_id(),
+ "Override UNK (<unk>) id.");
+ABSL_FLAG(int32, bos_id, kDefaultTrainerSpec.bos_id(),
+ "Override BOS (<s>) id. Set -1 to disable BOS.");
+ABSL_FLAG(int32, eos_id, kDefaultTrainerSpec.eos_id(),
+ "Override EOS (</s>) id. Set -1 to disable EOS.");
+ABSL_FLAG(int32, pad_id, kDefaultTrainerSpec.pad_id(),
+ "Override PAD (<pad>) id. Set -1 to disable PAD.");
+ABSL_FLAG(std::string, unk_piece, kDefaultTrainerSpec.unk_piece(),
+ "Override UNK (<unk>) piece.");
+ABSL_FLAG(std::string, bos_piece, kDefaultTrainerSpec.bos_piece(),
+ "Override BOS (<s>) piece.");
+ABSL_FLAG(std::string, eos_piece, kDefaultTrainerSpec.eos_piece(),
+ "Override EOS (</s>) piece.");
+ABSL_FLAG(std::string, pad_piece, kDefaultTrainerSpec.pad_piece(),
+ "Override PAD (<pad>) piece.");
+ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
+ "Dummy surface string for <unk>. In decoding <unk> is decoded to "
+ "`unk_surface`.");
+ABSL_FLAG(bool, train_extremely_large_corpus,
+ kDefaultTrainerSpec.train_extremely_large_corpus(),
+ "Increase bit depth for unigram tokenization.");
int main(int argc, char *argv[]) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
sentencepiece::TrainerSpec trainer_spec;
sentencepiece::NormalizerSpec normalizer_spec;
NormalizerSpec denormalizer_spec;
- CHECK(!FLAGS_input.empty());
- CHECK(!FLAGS_model_prefix.empty());
+ CHECK(!absl::GetFlag(FLAGS_input).empty());
+ CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
// Populates the value from flags to spec.
-#define SetTrainerSpecFromFlag(name) trainer_spec.set_##name(FLAGS_##name);
+#define SetTrainerSpecFromFlag(name) \
+ trainer_spec.set_##name(absl::GetFlag(FLAGS_##name));
#define SetNormalizerSpecFromFlag(name) \
- normalizer_spec.set_##name(FLAGS_##name);
-
-#define SetRepeatedTrainerSpecFromFlag(name) \
- if (!FLAGS_##name.empty()) { \
- for (const auto &v : sentencepiece::util::StrSplitAsCSV(FLAGS_##name)) { \
- trainer_spec.add_##name(v); \
- } \
+ normalizer_spec.set_##name(absl::GetFlag(FLAGS_##name));
+
+#define SetRepeatedTrainerSpecFromFlag(name) \
+ if (!absl::GetFlag(FLAGS_##name).empty()) { \
+ for (const auto &v : \
+ sentencepiece::util::StrSplitAsCSV(absl::GetFlag(FLAGS_##name))) { \
+ trainer_spec.add_##name(v); \
+ } \
}
SetRepeatedTrainerSpecFromFlag(input);
@@ -185,21 +194,21 @@ int main(int argc, char *argv[]) {
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
SetTrainerSpecFromFlag(train_extremely_large_corpus);
- normalizer_spec.set_name(FLAGS_normalization_rule_name);
+ normalizer_spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name));
SetNormalizerSpecFromFlag(normalization_rule_tsv);
SetNormalizerSpecFromFlag(add_dummy_prefix);
SetNormalizerSpecFromFlag(remove_extra_whitespaces);
- if (!FLAGS_denormalization_rule_tsv.empty()) {
+ if (!absl::GetFlag(FLAGS_denormalization_rule_tsv).empty()) {
denormalizer_spec.set_normalization_rule_tsv(
- FLAGS_denormalization_rule_tsv);
+ absl::GetFlag(FLAGS_denormalization_rule_tsv));
denormalizer_spec.set_add_dummy_prefix(false);
denormalizer_spec.set_remove_extra_whitespaces(false);
denormalizer_spec.set_escape_whitespaces(false);
}
CHECK_OK(sentencepiece::SentencePieceTrainer::PopulateModelTypeFromString(
- FLAGS_model_type, &trainer_spec));
+ absl::GetFlag(FLAGS_model_type), &trainer_spec));
CHECK_OK(sentencepiece::SentencePieceTrainer::Train(
trainer_spec, normalizer_spec, denormalizer_spec));
diff --git a/src/test_main.cc b/src/test_main.cc
index 354554f..b3170e2 100644
--- a/src/test_main.cc
+++ b/src/test_main.cc
@@ -12,19 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "flags.h"
+#include "init.h"
#include "testharness.h"
#ifdef OS_WIN
-DEFINE_string(test_srcdir, "..\\data", "Data directory.");
+ABSL_FLAG(std::string, test_srcdir, "..\\data", "Data directory.");
#else
-DEFINE_string(test_srcdir, "../data", "Data directory.");
+ABSL_FLAG(std::string, test_srcdir, "../data", "Data directory.");
#endif
-DEFINE_string(test_tmpdir, "test_tmp", "Temporary directory.");
+ABSL_FLAG(std::string, test_tmpdir, "test_tmp", "Temporary directory.");
int main(int argc, char **argv) {
- sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true);
+ sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
sentencepiece::test::RunAllTests();
return 0;
}
diff --git a/src/testharness.cc b/src/testharness.cc
index 76746df..e852d3f 100644
--- a/src/testharness.cc
+++ b/src/testharness.cc
@@ -26,6 +26,8 @@
#include <vector>
#include "common.h"
+#include "init.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"
@@ -56,9 +58,9 @@ bool RegisterTest(const char *base, const char *name, void (*func)()) {
int RunAllTests() {
int num = 0;
#ifdef OS_WIN
- _mkdir(FLAGS_test_tmpdir.c_str());
+ _mkdir(absl::GetFlag(FLAGS_test_tmpdir).c_str());
#else
- mkdir(FLAGS_test_tmpdir.c_str(), S_IRUSR | S_IWUSR | S_IXUSR);
+ mkdir(absl::GetFlag(FLAGS_test_tmpdir).c_str(), S_IRUSR | S_IWUSR | S_IXUSR);
#endif
if (tests == nullptr) {
diff --git a/src/testharness.h b/src/testharness.h
index 6962649..193ec74 100644
--- a/src/testharness.h
+++ b/src/testharness.h
@@ -21,11 +21,12 @@
#include <string>
#include "common.h"
-#include "flags.h"
+#include "init.h"
+#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/string_view.h"
-DECLARE_string(test_tmpdir);
-DECLARE_string(test_srcdir);
+ABSL_DECLARE_FLAG(std::string, test_tmpdir);
+ABSL_DECLARE_FLAG(std::string, test_srcdir);
namespace sentencepiece {
namespace test {
@@ -130,11 +131,11 @@ class Tester {
#define EXPECT_OK(c) EXPECT_EQ(c, ::sentencepiece::util::OkStatus())
#define EXPECT_NOT_OK(c) EXPECT_NE(c, ::sentencepiece::util::OkStatus())
-#define EXPECT_DEATH(statement, condition) \
- { \
- error::SetTestCounter(1); \
- statement; \
- error::SetTestCounter(0); \
+#define EXPECT_DEATH(statement, condition) \
+ { \
+ sentencepiece::error::SetTestCounter(1); \
+ statement; \
+ sentencepiece::error::SetTestCounter(0); \
};
#define ASSERT_TRUE EXPECT_TRUE
diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc
index 7a88e9b..0144376 100644
--- a/src/trainer_interface_test.cc
+++ b/src/trainer_interface_test.cc
@@ -441,7 +441,8 @@ TEST(TrainerInterfaceTest, SerializeTest) {
}
TEST(TrainerInterfaceTest, CharactersTest) {
- const std::string input_file = util::JoinPath(FLAGS_test_tmpdir, "input");
+ const std::string input_file =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
{
auto output = filesystem::NewWritableFile(input_file);
// Make a single line with 50 "a", 49 "あ", and 1 "b".
@@ -507,8 +508,8 @@ TEST(TrainerInterfaceTest, MultiFileSentenceIteratorTest) {
std::vector<std::string> files;
std::vector<std::string> expected;
for (int i = 0; i < 10; ++i) {
- const std::string file =
- util::JoinPath(FLAGS_test_tmpdir, absl::StrCat("input", i));
+ const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
+ absl::StrCat("input", i));
auto output = filesystem::NewWritableFile(file);
int num_line = (rand() % 100) + 1;
for (int n = 0; n < num_line; ++n) {
@@ -529,8 +530,8 @@ TEST(TrainerInterfaceTest, MultiFileSentenceIteratorTest) {
TEST(TrainerInterfaceTest, MultiFileSentenceIteratorErrorTest) {
std::vector<std::string> files;
for (int i = 0; i < 10; ++i) {
- const std::string file =
- util::JoinPath(FLAGS_test_tmpdir, absl::StrCat("input_not_exist", i));
+ const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
+ absl::StrCat("input_not_exist", i));
files.push_back(file);
}
diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc
index f15d7f9..cca9936 100644
--- a/src/unigram_model_trainer_test.cc
+++ b/src/unigram_model_trainer_test.cc
@@ -38,12 +38,14 @@ TEST(UnigramTrainerTest, TrainerModelTest) {
static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";
TEST(UnigramTrainerTest, EndToEndTest) {
- const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestInputData);
+ const std::string input =
+ util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);
ASSERT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat(
- "--model_prefix=", util::JoinPath(FLAGS_test_tmpdir, "tmp_model"),
+ "--model_prefix=",
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
" --input=", input,
" --vocab_size=8000 --normalization_rule_name=identity",
" --model_type=unigram --user_defined_symbols=<user>",
@@ -51,8 +53,9 @@ TEST(UnigramTrainerTest, EndToEndTest) {
.ok());
SentencePieceProcessor sp;
- EXPECT_TRUE(
- sp.Load(util::JoinPath(FLAGS_test_tmpdir, "tmp_model.model")).ok());
+ EXPECT_TRUE(sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
+ "tmp_model.model"))
+ .ok());
EXPECT_EQ(8000, sp.GetPieceSize());
const int cid = sp.PieceToId("<ctrl>");
diff --git a/src/util_test.cc b/src/util_test.cc
index 022a0f0..71d006f 100644
--- a/src/util_test.cc
+++ b/src/util_test.cc
@@ -332,7 +332,7 @@ TEST(UtilTest, InputOutputBufferTest) {
{
auto output = filesystem::NewWritableFile(
- util::JoinPath(FLAGS_test_tmpdir, "test_file"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
for (size_t i = 0; i < kData.size(); ++i) {
output->WriteLine(kData[i]);
}
@@ -340,7 +340,7 @@ TEST(UtilTest, InputOutputBufferTest) {
{
auto input = filesystem::NewReadableFile(
- util::JoinPath(FLAGS_test_tmpdir, "test_file"));
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
std::string line;
for (size_t i = 0; i < kData.size(); ++i) {
EXPECT_TRUE(input->ReadLine(&line));
diff --git a/src/word_model_trainer_test.cc b/src/word_model_trainer_test.cc
index 061901d..c4a8bc6 100644
--- a/src/word_model_trainer_test.cc
+++ b/src/word_model_trainer_test.cc
@@ -31,8 +31,10 @@ namespace {
#define WS "\xE2\x96\x81"
std::string RunTrainer(const std::vector<std::string> &input, int size) {
- const std::string input_file = util::JoinPath(FLAGS_test_tmpdir, "input");
- const std::string model_prefix = util::JoinPath(FLAGS_test_tmpdir, "model");
+ const std::string input_file =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
+ const std::string model_prefix =
+ util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
diff --git a/third_party/absl/flags/flag.cc b/third_party/absl/flags/flag.cc
new file mode 100644
index 0000000..09ff78f
--- /dev/null
+++ b/third_party/absl/flags/flag.cc
@@ -0,0 +1,220 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/absl/flags/flag.h"
+
+#include <algorithm>
+#include <iostream>
+#include <map>
+#include <sstream>
+#include <string>
+
+#include "config.h"
+#include "src/common.h"
+#include "src/util.h"
+
+// Flags defined by the flag library itself. ParseCommandLine() in this file
+// checks help and version once every argument has been processed;
+// minloglevel is only defined here and is not read anywhere in this file.
+ABSL_FLAG(bool, help, false, "show help");
+ABSL_FLAG(bool, version, false, "show version");
+ABSL_FLAG(int, minloglevel, 0,
+ "Messages logged at a lower level than this don't actually get "
+ "logged anywhere");
+
+namespace absl {
+namespace internal {
+namespace {
+// Renders |value| as text by streaming it through operator<<. The result is
+// stored as the default value shown in the help output.
+template <typename T>
+std::string to_str(const T &value) {
+  std::ostringstream stream;
+  stream << value;
+  return stream.str();
+}
+
+// Booleans are spelled out as "true"/"false" rather than streamed.
+template <>
+std::string to_str<bool>(const bool &value) {
+  if (value) return "true";
+  return "false";
+}
+
+// Strings are wrapped in double quotes so that an empty default is visible.
+template <>
+std::string to_str<std::string>(const std::string &value) {
+  std::string quoted("\"");
+  quoted += value;
+  quoted += "\"";
+  return quoted;
+}
+} // namespace
+
+// Type-erased record describing one registered flag. The Flag<T> constructor
+// fills in every field and hands a pointer to this record to RegisterFlag().
+struct FlagFunc {
+ // Flag name without leading dashes; PrintHelp() prepends "--" when listing.
+ const char *name;
+ // Human-readable description printed by PrintHelp().
+ const char *help;
+ // Textual type name printed by PrintHelp() (ABSL_FLAG passes #Type).
+ const char *type;
+ // Default value rendered with to_str<T>() at construction time.
+ std::string default_value;
+ // Callback that parses a textual value and stores it in the owning Flag<T>.
+ std::function<void(const std::string &)> set_value;
+};
+
+namespace {
+
+// Name -> record lookup table, and the same records in registration order.
+using FlagMap = std::map<std::string, FlagFunc *>;
+using FlagList = std::vector<FlagFunc *>;
+
+// Returns the process-wide name -> flag table. The table is created on first
+// use and never deleted (presumably so that flag objects constructed during
+// static initialization can always register -- confirm).
+FlagMap *GetFlagMap() {
+  static FlagMap *const registry = new FlagMap;
+  return registry;
+}
+
+// Returns the process-wide list of flags in registration order. Like the map
+// above, it is created on first use and never deleted.
+FlagList *GetFlagList() {
+  static FlagList *const ordered = new FlagList;
+  return ordered;
+}
+
+// Splits the argument at argv[0] (and possibly argv[1]) into a flag key and
+// value.
+//
+//   argc, argv  remaining arguments; argv[0] is the one being examined.
+//   key, value  outputs; both are cleared first.
+//   used_args   output; number of arguments consumed (1 or 2).
+//
+// Returns false when argv[0] does not start with '-', i.e. it is not a flag.
+// Accepted forms are "-key=value", "--key=value", "--key value" and a bare
+// "--key" (value left empty). A following argument that itself starts with
+// '-' is never taken as the value.
+bool CommandLineGetFlag(int argc, char **argv, std::string *key,
+                        std::string *value, int *used_args) {
+  key->clear();
+  value->clear();
+  *used_args = 1;
+
+  const char *cursor = argv[0];
+  if (*cursor != '-') return false;
+
+  // Skip one leading dash, or two when the flag is written as "--name".
+  cursor += (cursor[1] == '-') ? 2 : 1;
+
+  const std::string body(cursor);
+  const std::string::size_type eq = body.find('=');
+  if (eq != std::string::npos) {
+    // "key=value" packed into a single argument.
+    key->assign(body, 0, eq);
+    value->assign(body, eq + 1, std::string::npos);
+    return true;
+  }
+
+  *key = body;
+
+  // No '=': the value, if any, is the next argument unless that argument
+  // looks like another flag.
+  if (argc != 1 && argv[1][0] != '-') {
+    value->assign(argv[1]);
+    *used_args = 2;
+  }
+  return true;
+}
+
+// Builds the usage text: the package string, a usage line for |programname|,
+// then one line per registered flag (in registration order) with its help
+// text, type name and default value.
+std::string PrintHelp(const char *programname) {
+  std::ostringstream help_text;
+  help_text << PACKAGE_STRING << "\n\n"
+            << "Usage: " << programname << " [options] files\n\n";
+
+  const FlagList &flags = *GetFlagList();
+  for (FlagList::const_iterator it = flags.begin(); it != flags.end(); ++it) {
+    const FlagFunc &entry = **it;
+    help_text << " --" << entry.name << " (" << entry.help << ")"
+              << " type: " << entry.type
+              << " default: " << entry.default_value << '\n';
+  }
+
+  help_text << "\n\n";
+  return help_text.str();
+}
+} // namespace
+
+// Records |func| under |name|. The list keeps registration order for
+// PrintHelp(); the map is what ParseCommandLine() searches by key. If |name|
+// is already present in the map, the existing map entry is kept.
+void RegisterFlag(const std::string &name, FlagFunc *func) {
+  FlagList *ordered = GetFlagList();
+  ordered->push_back(func);
+
+  FlagMap *by_name = GetFlagMap();
+  by_name->insert(FlagMap::value_type(name, func));
+}
+} // namespace internal
+
+// Stores |default_value| as the current value, fills in the FlagFunc record
+// owned by this flag (name, help, type name, printable default, and a setter
+// callback) and registers that record under |name|.
+template <typename T>
+Flag<T>::Flag(const char *name, const char *type, const char *help,
+              const T &default_value)
+    : value_(default_value), func_(new internal::FlagFunc) {
+  internal::FlagFunc &record = *func_;
+  record.name = name;
+  record.type = type;
+  record.help = help;
+  record.default_value = internal::to_str<T>(default_value);
+  // The callback keeps a pointer to this object, so the flag must outlive
+  // any use of the registry entry.
+  record.set_value = [this](const std::string &text) {
+    set_value_as_str(text);
+  };
+  RegisterFlag(name, func_.get());
+}
+
+// Nothing is done explicitly here. The FlagFunc record is owned through
+// func_ (a std::unique_ptr), while the raw pointers handed to RegisterFlag()
+// are not removed from the flag map/list by this destructor.
+template <typename T>
+Flag<T>::~Flag() {}
+
+// Returns a reference to the currently stored value.
+template <typename T>
+const T &Flag<T>::value() const {
+ return value_;
+}
+
+// Replaces the stored value with a copy of |value|.
+template <typename T>
+void Flag<T>::set_value(const T &value) {
+ value_ = value;
+}
+
+// Parses |value_as_str| with sentencepiece::string_util::lexical_cast and
+// writes the result into value_.
+// NOTE(review): whatever lexical_cast returns is discarded here, so a value
+// that fails to parse is presumably not reported to the caller -- confirm
+// against lexical_cast's contract in src/util.h.
+template <typename T>
+void Flag<T>::set_value_as_str(const std::string &value_as_str) {
+ sentencepiece::string_util::lexical_cast<T>(value_as_str, &value_);
+}
+
+// Boolean flags: an empty string (e.g. a bare "--flag" on the command line)
+// means true; any other text is parsed with lexical_cast<bool>.
+template <>
+void Flag<bool>::set_value_as_str(const std::string &value_as_str) {
+ if (value_as_str.empty())
+ value_ = true;
+ else
+ sentencepiece::string_util::lexical_cast<bool>(value_as_str, &value_);
+}
+
+// Explicit instantiations. The Flag<T> member functions are defined in this
+// .cc file rather than in flag.h, so these are the value types for which
+// definitions are emitted here.
+template class Flag<std::string>;
+template class Flag<int32>;
+template class Flag<double>;
+template class Flag<bool>;
+template class Flag<int64>;
+template class Flag<uint64>;
+
+// Applies every flag found in argv[1..argc) to the registered Flag objects
+// and returns argv[0] followed by the arguments that were not flags.
+//
+// Returns an empty vector when argc is 0. An argument that looks like a flag
+// but is not registered prints an error plus the help text to stderr and
+// calls sentencepiece::error::Exit(1). After all arguments are processed,
+// --help prints the help text and --version prints the package/version
+// string; both then call sentencepiece::error::Exit(0).
+std::vector<char *> ParseCommandLine(int argc, char *argv[]) {
+  if (argc == 0) return {};
+
+  std::vector<char *> output_args;
+  output_args.reserve(argc);
+  output_args.push_back(argv[0]);
+
+  std::string key, value;
+  int used_argc = 0;
+  int i = 1;
+  while (i < argc) {
+    const bool is_flag = internal::CommandLineGetFlag(
+        argc - i, argv + i, &key, &value, &used_argc);
+    if (!is_flag) {
+      // Positional argument: pass it through untouched.
+      output_args.push_back(argv[i]);
+    } else {
+      const auto *flag_map = internal::GetFlagMap();
+      const auto found = flag_map->find(key);
+      if (found != flag_map->end()) {
+        found->second->set_value(value);
+      } else {
+        std::cerr << "Unknown/Invalid flag " << key << "\n\n"
+                  << internal::PrintHelp(argv[0]);
+        sentencepiece::error::Exit(1);
+      }
+    }
+    // CommandLineGetFlag reports how many arguments it looked at (1 or 2).
+    i += used_argc;
+  }
+
+  const bool want_help = absl::GetFlag(FLAGS_help);
+  if (want_help) {
+    std::cout << internal::PrintHelp(argv[0]);
+    sentencepiece::error::Exit(0);
+  } else if (absl::GetFlag(FLAGS_version)) {
+    std::cout << PACKAGE_STRING << " " << VERSION << std::endl;
+    sentencepiece::error::Exit(0);
+  }
+
+  return output_args;
+}
+} // namespace absl
diff --git a/third_party/absl/flags/flag.h b/third_party/absl/flags/flag.h
new file mode 100644
index 0000000..f3bf71d
--- /dev/null
+++ b/third_party/absl/flags/flag.h
@@ -0,0 +1,64 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ABSL_FLAGS_FLAG_H_
+#define ABSL_FLAGS_FLAG_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace absl {
+namespace internal {
+// Per-flag registry record; only forward-declared here, the definition is in
+// flag.cc.
+struct FlagFunc;
+
+// Adds |func| to the flag registry under |name|. Called from the Flag<T>
+// constructor.
+void RegisterFlag(const std::string &name, FlagFunc *func);
+} // namespace internal
+
+// Minimal stand-in for absl::Flag<T>. Holds one value of type T and a
+// registry record through which ParseCommandLine() can set the value from
+// text. The member functions are defined in flag.cc, which explicitly
+// instantiates the template for std::string, int32, double, bool, int64 and
+// uint64.
+template <typename T>
+class Flag {
+ public:
+  // Creates the flag with |default_value| as its initial value and registers
+  // it under |name|. |type| and |help| are only used for the help output.
+  Flag(const char *name, const char *type, const char *help,
+       const T &default_value);
+  virtual ~Flag();
+
+  // Returns the current value.
+  const T &value() const;
+
+  // Overwrites the current value.
+  void set_value(const T &value);
+
+  // Parses |value_as_str| and stores the result. For Flag<bool>, an empty
+  // string sets the flag to true.
+  void set_value_as_str(const std::string &value_as_str);
+
+ private:
+  T value_;
+  // Registry record for this flag; the registry stores the raw pointer.
+  std::unique_ptr<internal::FlagFunc> func_;
+};
+
+// Returns a reference to the current value of |flag|, e.g.
+// absl::GetFlag(FLAGS_name).
+template <typename T>
+const T &GetFlag(const Flag<T> &flag) {
+ return flag.value();
+}
+
+// Sets |flag| to |v|. |v| is first converted by constructing a T from it, so
+// V may be any type from which T can be direct-initialized.
+template <typename T, typename V>
+void SetFlag(Flag<T> *flag, const V &v) {
+ const T value(v);
+ flag->set_value(value);
+}
+
+// Applies the flags found in argv to the registered Flag objects and returns
+// argv[0] followed by the non-flag arguments (empty when argc is 0). Defined
+// in flag.cc.
+std::vector<char *> ParseCommandLine(int argc, char *argv[]);
+} // namespace absl
+
+// Defines the flag object FLAGS_<name> of type absl::Flag<Type>. The flag
+// name and type name are passed to the constructor as stringified tokens.
+// Note that the expansion already ends with a semicolon.
+#define ABSL_FLAG(Type, name, default_value, help) \
+  absl::Flag<Type> FLAGS_##name(#name, #Type, help, default_value);
+
+// Declares (extern) a flag object that is defined with ABSL_FLAG in another
+// translation unit. The expansion already ends with a semicolon.
+#define ABSL_DECLARE_FLAG(Type, name) extern absl::Flag<Type> FLAGS_##name;
+
+#endif // ABSL_FLAGS_FLAG_H_