Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2018-04-30 18:00:35 +0300
committerTaku Kudo <taku@google.com>2018-04-30 18:00:35 +0300
commitf228e556b484015cdb29c467416dd95ca38360cd (patch)
tree5b2b31e2a4c81d9f701610ecdfcaa4309d52b5d1 /src
parent36a3b35e17f995c5fb77c0397122b0d6d41fef1b (diff)
Reimplement Trainer with Proto reflection
Diffstat (limited to 'src')
-rw-r--r--src/flags.cc51
-rw-r--r--src/sentencepiece_model.proto6
-rw-r--r--src/sentencepiece_processor.cc1
-rw-r--r--src/sentencepiece_trainer.cc314
-rw-r--r--src/sentencepiece_trainer.h38
-rw-r--r--src/sentencepiece_trainer_test.cc89
-rw-r--r--src/spm_train_main.cc132
-rw-r--r--src/util.h46
-rw-r--r--src/util_test.cc25
9 files changed, 488 insertions, 214 deletions
diff --git a/src/flags.cc b/src/flags.cc
index bdc4571..0324c19 100644
--- a/src/flags.cc
+++ b/src/flags.cc
@@ -14,6 +14,7 @@
#include "flags.h"
#include "common.h"
+#include "util.h"
#include <algorithm>
#include <cctype>
@@ -44,23 +45,6 @@ FlagMap *GetFlagMap() {
return &flag_map;
}
-bool IsTrue(const std::string &value) {
- const char *kTrue[] = {"1", "t", "true", "y", "yes"};
- const char *kFalse[] = {"0", "f", "false", "n", "no"};
- std::string lower_value = value;
- std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
- ::tolower);
- for (size_t i = 0; i < 5; ++i) {
- if (lower_value == kTrue[i]) {
- return true;
- } else if (lower_value == kFalse[i]) {
- return false;
- }
- }
- LOG(FATAL) << "cannot parse boolean value: " << value;
- return false;
-}
-
bool SetFlag(const std::string &name, const std::string &value) {
auto it = GetFlagMap()->find(name);
if (it == GetFlagMap()->end()) {
@@ -85,31 +69,26 @@ bool SetFlag(const std::string &name, const std::string &value) {
}
}
+#define DEFINE_ARG(FLAG_TYPE, CPP_TYPE) \
+ case FLAG_TYPE: { \
+ CPP_TYPE *s = reinterpret_cast<CPP_TYPE *>(flag->storage); \
+ CHECK(string_util::lexical_cast<CPP_TYPE>(v, s)); \
+ break; \
+ }
+
switch (flag->type) {
- case I:
- *reinterpret_cast<int32 *>(flag->storage) = atoi(v.c_str());
- break;
- case B:
- *(reinterpret_cast<bool *>(flag->storage)) = IsTrue(v);
- break;
- case I64:
- *reinterpret_cast<int64 *>(flag->storage) = atoll(v.c_str());
- break;
- case U64:
- *reinterpret_cast<uint64 *>(flag->storage) = atoll(v.c_str());
- break;
- case D:
- *reinterpret_cast<double *>(flag->storage) = strtod(v.c_str(), nullptr);
- break;
- case S:
- *reinterpret_cast<std::string *>(flag->storage) = v;
- break;
+ DEFINE_ARG(I, int32);
+ DEFINE_ARG(B, bool);
+ DEFINE_ARG(I64, int64);
+ DEFINE_ARG(U64, uint64);
+ DEFINE_ARG(D, double);
+ DEFINE_ARG(S, std::string);
default:
break;
}
return true;
-}
+} // namespace
bool CommandLineGetFlag(int argc, char **argv, std::string *key,
std::string *value, int *used_args) {
diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto
index 3a3d5c4..cfb543c 100644
--- a/src/sentencepiece_model.proto
+++ b/src/sentencepiece_model.proto
@@ -175,6 +175,12 @@ message NormalizerSpec {
// This field must be true to train sentence piece model.
optional bool escape_whitespaces = 5 [ default = true ];
+ // Custom normalization rule file in TSV format.
+ // https://github.com/google/sentencepiece/blob/master/doc/normalization.md
+ // This field is only used in SentencePieceTrainer::Train() method, which
+ // compiles the rule into the binary rule stored in `precompiled_charsmap`.
+ optional string normalization_rule_tsv = 6;
+
// Customized extensions: the range of field numbers
// are open to third-party extensions.
extensions 200 to max;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 593f581..6b29a56 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -31,6 +31,7 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
const char kUnknownSymbol[] = " \xE2\x81\x87 ";
+
} // namespace
SentencePieceProcessor::SentencePieceProcessor() {}
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 5170f3a..bfc2f8e 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -13,7 +13,6 @@
// limitations under the License.!
#include "sentencepiece_trainer.h"
-#include <mutex>
#include <string>
#include "builder.h"
@@ -25,196 +24,179 @@
#include "trainer_factory.h"
#include "util.h"
-namespace {
-static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
-static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
-} // namespace
-
-DEFINE_string(input, "", "comma separated list of input sentences");
-DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
- "Input format. Supported format is `text` or `tsv`.");
-DEFINE_string(model_prefix, "", "output model prefix");
-DEFINE_string(model_type, "unigram",
- "model algorithm: unigram, bpe, word or char");
-DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
-DEFINE_string(accept_language, "",
- "comma-separated list of languages this model can accept");
-DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
- "character coverage to determine the minimum symbols");
-DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
- "maximum size of sentences the trainer loads");
-DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
- "maximum size of sentences to make seed sentence piece");
-DEFINE_int32(training_sentence_size,
- kDefaultTrainerSpec.training_sentence_size(),
- "maximum size of sentences to train sentence pieces");
-DEFINE_int32(seed_sentencepiece_size,
- kDefaultTrainerSpec.seed_sentencepiece_size(),
- "the size of seed sentencepieces");
-DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
- "Keeps top shrinking_factor pieces with respect to the loss");
-DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
- "number of threads for training");
-DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
- "number of EM sub-iterations");
-DEFINE_int32(max_sentencepiece_length,
- kDefaultTrainerSpec.max_sentencepiece_length(),
- "maximum length of sentence piece");
-DEFINE_bool(split_by_unicode_script,
- kDefaultTrainerSpec.split_by_unicode_script(),
- "use Unicode script to split sentence pieces");
-DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
- "use a white space to split sentence pieces");
-DEFINE_string(control_symbols, "", "comma separated list of control symbols");
-DEFINE_string(user_defined_symbols, "",
- "comma separated list of user defined symbols");
-DEFINE_string(normalization_rule_name, "nfkc",
- "Normalization rule name. "
- "Choose from nfkc or identity");
-DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
-DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
- "Add dummy whitespace at the beginning of text");
-DEFINE_bool(remove_extra_whitespaces,
- kDefaultNormalizerSpec.remove_extra_whitespaces(),
- "Removes leading, trailing, and "
- "duplicate internal whitespace");
-DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
- "If set to false, --vocab_size is considered as a soft limit.");
-DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
- "Override UNK (<unk>) id.");
-DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
- "Override BOS (<s>) id. Set -1 to disable BOS.");
-DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
- "Override EOS (</s>) id. Set -1 to disable EOS.");
-DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
- "Override PAD (<pad>) id. Set -1 to disable PAD.");
-
-using sentencepiece::NormalizerSpec;
-using sentencepiece::TrainerSpec;
-using sentencepiece::normalizer::Builder;
-
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nfkc";
+} // namespace
+
+// static
+util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name(kDefaultNormalizerName);
+ Train(trainer_spec, normalizer_spec);
+ return util::OkStatus();
+}
-NormalizerSpec MakeNormalizerSpecFromFlags() {
- if (!FLAGS_normalization_rule_tsv.empty()) {
- const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile(
- FLAGS_normalization_rule_tsv);
- sentencepiece::NormalizerSpec spec;
- spec.set_name("user_defined");
- spec.set_precompiled_charsmap(
- sentencepiece::normalizer::Builder::CompileCharsMap(chars_map));
- return spec;
+// static
+util::Status SentencePieceTrainer::Train(
+ const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
+ auto copied_normalizer_spec = normalizer_spec;
+
+ if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
+ if (!copied_normalizer_spec.precompiled_charsmap().empty()) {
+ return util::InternalError("precompiled_charsmap is already defined.");
+ }
+
+ const auto chars_map = normalizer::Builder::BuildMapFromFile(
+ copied_normalizer_spec.normalization_rule_tsv());
+ copied_normalizer_spec.set_precompiled_charsmap(
+ normalizer::Builder::CompileCharsMap(chars_map));
+ copied_normalizer_spec.set_name("user_defined");
+ } else {
+ if (copied_normalizer_spec.name().empty()) {
+ copied_normalizer_spec.set_name(kDefaultNormalizerName);
+ }
+
+ if (copied_normalizer_spec.precompiled_charsmap().empty()) {
+ *(copied_normalizer_spec.mutable_precompiled_charsmap()) =
+ normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name())
+ .precompiled_charsmap();
+ }
}
- return sentencepiece::normalizer::Builder::GetNormalizerSpec(
- FLAGS_normalization_rule_name);
-}
+ auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
+ trainer->Train();
-TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) {
- const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
- {"unigram", TrainerSpec::UNIGRAM},
- {"bpe", TrainerSpec::BPE},
- {"word", TrainerSpec::WORD},
- {"char", TrainerSpec::CHAR}};
- return port::FindOrDie(kModelTypeMap, type);
+ return util::OkStatus();
}
-// Populates the value from flags to spec.
-#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name);
+// static
+util::Status SentencePieceTrainer::SetProtoField(
+ const std::string &field_name, const std::string &value,
+ google::protobuf::Message *message) {
+ const auto *descriptor = message->GetDescriptor();
+ const auto *reflection = message->GetReflection();
+
+ if (descriptor == nullptr || reflection == nullptr) {
+ return util::InternalError("Reflection is not supported.");
+ }
+
+ const auto *field = descriptor->FindFieldByName(std::string(field_name));
-#define SetNormalizerSpecFromFlag(name) \
- normalizer_spec->set_##name(FLAGS_##name);
+ if (field == nullptr) {
+ return util::NotFoundError(std::string("Unknown field name \"") +
+ field_name + "\" in " +
+ descriptor->DebugString());
+ }
-#define SetRepeatedTrainerSpecFromFlag(name) \
- if (!FLAGS_##name.empty()) { \
- for (const auto v : \
- sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
- trainer_spec->add_##name(v); \
- } \
+ std::vector<std::string> values = {value};
+ if (field->is_repeated()) values = string_util::Split(value, ",");
+
+#define SET_FIELD(METHOD_TYPE, v) \
+ if (field->is_repeated()) \
+ reflection->Add##METHOD_TYPE(message, field, v); \
+ else \
+ reflection->Set##METHOD_TYPE(message, field, v);
+
+#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \
+ EMPTY) \
+ case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \
+ CPP_TYPE v; \
+ if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \
+ return util::InvalidArgumentError(std::string("Cannot parse \"") + \
+ value + "\" as \"" + \
+ field->type_name() + "\"."); \
+ SET_FIELD(METHOD_TYPE, v); \
+ break; \
}
-void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
- NormalizerSpec *normalizer_spec) {
- CHECK_NOTNULL(trainer_spec);
- CHECK_NOTNULL(normalizer_spec);
-
- SetTrainerSpecFromFlag(input_format);
- SetTrainerSpecFromFlag(model_prefix);
- SetTrainerSpecFromFlag(vocab_size);
- SetTrainerSpecFromFlag(character_coverage);
- SetTrainerSpecFromFlag(input_sentence_size);
- SetTrainerSpecFromFlag(mining_sentence_size);
- SetTrainerSpecFromFlag(training_sentence_size);
- SetTrainerSpecFromFlag(seed_sentencepiece_size);
- SetTrainerSpecFromFlag(shrinking_factor);
- SetTrainerSpecFromFlag(num_threads);
- SetTrainerSpecFromFlag(num_sub_iterations);
- SetTrainerSpecFromFlag(max_sentencepiece_length);
- SetTrainerSpecFromFlag(split_by_unicode_script);
- SetTrainerSpecFromFlag(split_by_whitespace);
- SetTrainerSpecFromFlag(hard_vocab_limit);
- SetTrainerSpecFromFlag(unk_id);
- SetTrainerSpecFromFlag(bos_id);
- SetTrainerSpecFromFlag(eos_id);
- SetTrainerSpecFromFlag(pad_id);
- SetRepeatedTrainerSpecFromFlag(accept_language);
- SetRepeatedTrainerSpecFromFlag(control_symbols);
- SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
-
- *normalizer_spec = MakeNormalizerSpecFromFlags();
- SetNormalizerSpecFromFlag(add_dummy_prefix);
- SetNormalizerSpecFromFlag(remove_extra_whitespaces);
-
- for (const auto &filename :
- sentencepiece::string_util::Split(FLAGS_input, ",")) {
- trainer_spec->add_input(filename);
+ for (const auto &value : values) {
+ switch (field->cpp_type()) {
+ DEFINE_SET_FIELD(INT32, int32, i, Int32, "");
+ DEFINE_SET_FIELD(INT64, int64, i, Int64, "");
+ DEFINE_SET_FIELD(UINT32, uint32, i, UInt32, "");
+ DEFINE_SET_FIELD(UINT64, uint64, i, UInt64, "");
+ DEFINE_SET_FIELD(DOUBLE, double, d, Double, "");
+ DEFINE_SET_FIELD(FLOAT, float, f, Float, "");
+ DEFINE_SET_FIELD(BOOL, bool, b, Bool, "true");
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
+ SET_FIELD(String, value);
+ break;
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ const auto *enum_value =
+ field->enum_type()->FindValueByName(string_util::ToUpper(value));
+ if (enum_value == nullptr)
+ return util::InvalidArgumentError(
+ std::string("Unknown enumeration value of \"") + value +
+ "\" for field \"" + field->name() + "\".");
+ SET_FIELD(Enum, enum_value);
+ break;
+ }
+ default:
+ return util::UnimplementedError(std::string("Proto type \"") +
+ field->cpp_type_name() +
+ "\" is not supported.");
+ break;
+ }
}
- trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type));
+ return util::OkStatus();
}
-} // namespace
// static
-void SentencePieceTrainer::Train(int argc, char **argv) {
- TrainerSpec trainer_spec;
- NormalizerSpec normalizer_spec;
- {
- static std::mutex flags_mutex;
- std::lock_guard<std::mutex> lock(flags_mutex);
- sentencepiece::flags::ParseCommandLineFlags(argc, argv);
- CHECK_OR_HELP(input);
- CHECK_OR_HELP(model_prefix);
- MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec);
+util::Status SentencePieceTrainer::MergeSpecsFromArgs(
+ const std::string &args, TrainerSpec *trainer_spec,
+ NormalizerSpec *normalizer_spec) {
+ if (trainer_spec == nullptr || normalizer_spec == nullptr) {
+ return util::InternalError(
+ "`trainer_spec` and `normalizer_spec` must not be null.");
}
- SentencePieceTrainer::Train(trainer_spec, normalizer_spec);
-}
+ if (args.empty()) return util::OkStatus();
+
+ for (auto arg : string_util::SplitPiece(args, " ")) {
+ arg.Consume("--");
+ std::string key, value;
+ auto pos = arg.find("=");
+ if (pos == StringPiece::npos) {
+ key = arg.ToString();
+ } else {
+ key = arg.substr(0, pos).ToString();
+ value = arg.substr(pos + 1).ToString();
+ }
+
+ // Exception.
+ if (key == "normalization_rule_name") {
+ normalizer_spec->set_name(value);
+ continue;
+ }
+
+ const auto status_train = SetProtoField(key, value, trainer_spec);
+ if (status_train.ok()) continue;
+ if (!util::IsNotFound(status_train)) return status_train;
+
+ const auto status_norm = SetProtoField(key, value, normalizer_spec);
+ if (status_norm.ok()) continue;
+ if (!util::IsNotFound(status_norm)) return status_norm;
+
+ // Not found both in trainer_spec and normalizer_spec.
+ if (util::IsNotFound(status_train) && util::IsNotFound(status_norm)) {
+ return status_train;
+ }
+ }
-// static
-void SentencePieceTrainer::Train(const std::string &arg) {
- const std::vector<std::string> args =
- sentencepiece::string_util::Split(arg, " ");
- std::vector<char *> cargs(args.size() + 1);
- cargs[0] = const_cast<char *>("");
- for (size_t i = 0; i < args.size(); ++i)
- cargs[i + 1] = const_cast<char *>(args[i].data());
- SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]);
+ return util::OkStatus();
}
// static
-void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
- SentencePieceTrainer::Train(
- trainer_spec,
- normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName));
-}
+util::Status SentencePieceTrainer::Train(const std::string &args) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name(kDefaultNormalizerName);
-// static
-void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec,
- const NormalizerSpec &normalizer_spec) {
- auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec);
- trainer->Train();
+ CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
+
+ return Train(trainer_spec, normalizer_spec);
}
} // namespace sentencepiece
diff --git a/src/sentencepiece_trainer.h b/src/sentencepiece_trainer.h
index a3399da..9285b4d 100644
--- a/src/sentencepiece_trainer.h
+++ b/src/sentencepiece_trainer.h
@@ -16,6 +16,13 @@
#define SENTENCEPIECE_TRAINER_H_
#include <string>
+#include "sentencepiece_processor.h"
+
+namespace google {
+namespace protobuf {
+class Message;
+} // namespace protobuf
+} // namespace google
namespace sentencepiece {
@@ -24,21 +31,32 @@ class NormalizerSpec;
class SentencePieceTrainer {
public:
- // Entry point for main function.
- static void Train(int argc, char **argv);
-
- // Train from params with a single line.
- // "--input=foo --model_prefix=m --vocab_size=1024"
- static void Train(const std::string &arg);
-
// Trains SentencePiece model with `trainer_spec`.
// Default `normalizer_spec` is used.
- static void Train(const TrainerSpec &trainer_spec);
+ static util::Status Train(const TrainerSpec &trainer_spec);
// Trains SentencePiece model with `trainer_spec` and
// `normalizer_spec`.
- static void Train(const TrainerSpec &trainer_spec,
- const NormalizerSpec &normalizer_spec);
+ static util::Status Train(const TrainerSpec &trainer_spec,
+ const NormalizerSpec &normalizer_spec);
+
+ // Trains SentencePiece model with command-line string in `args`,
+ // e.g.,
+ // '--input=data --model_prefix=m --vocab_size=8192 model_type=unigram'
+ static util::Status Train(const std::string &args);
+
+ // Overrides `trainer_spec` and `normalizer_spec` with the
+ // command-line string in `args`.
+ static util::Status MergeSpecsFromArgs(const std::string &args,
+ TrainerSpec *trainer_spec,
+ NormalizerSpec *normalizer_spec);
+
+ // Helper function to set `field_name=value` in `message`.
+ // When `field_name` is repeated, multiple values can be passed
+ // with comma-separated values. `field_name` must not be a nested message.
+ static util::Status SetProtoField(const std::string &field_name,
+ const std::string &value,
+ google::protobuf::Message *message);
SentencePieceTrainer() = delete;
~SentencePieceTrainer() = delete;
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index 0b4c12b..0c2107d 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -15,6 +15,7 @@
#include "sentencepiece_trainer.h"
#include "sentencepiece_model.pb.h"
#include "testharness.h"
+#include "util.h"
namespace sentencepiece {
namespace {
@@ -39,12 +40,98 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
"--normalization_rule_tsv=../data/nfkc.tsv");
}
+TEST(SentencePieceTrainerTest, TrainErrorTest) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_normalization_rule_tsv("foo.tsv");
+ normalizer_spec.set_precompiled_charsmap("foo");
+ EXPECT_NOT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
+}
+
TEST(SentencePieceTrainerTest, TrainTest) {
TrainerSpec trainer_spec;
trainer_spec.add_input("../data/botchan.txt");
trainer_spec.set_model_prefix("m");
trainer_spec.set_vocab_size(1000);
- SentencePieceTrainer::Train(trainer_spec);
+ NormalizerSpec normalizer_spec;
+ EXPECT_OK(SentencePieceTrainer::Train(trainer_spec, normalizer_spec));
+ EXPECT_OK(SentencePieceTrainer::Train(trainer_spec));
+}
+
+TEST(SentencePieceTrainerTest, SetProtoFieldTest) {
+ TrainerSpec spec;
+
+ EXPECT_NOT_OK(SentencePieceTrainer::SetProtoField("dummy", "1000", &spec));
+
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("vocab_size", "1000", &spec));
+ EXPECT_EQ(1000, spec.vocab_size());
+ EXPECT_NOT_OK(
+ SentencePieceTrainer::SetProtoField("vocab_size", "UNK", &spec));
+
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "TSV", &spec));
+ EXPECT_EQ("TSV", spec.input_format());
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("input_format", "123", &spec));
+ EXPECT_EQ("123", spec.input_format());
+
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("split_by_whitespace", "false",
+ &spec));
+ EXPECT_FALSE(spec.split_by_whitespace());
+ EXPECT_OK(
+ SentencePieceTrainer::SetProtoField("split_by_whitespace", "", &spec));
+ EXPECT_TRUE(spec.split_by_whitespace());
+
+ EXPECT_OK(
+ SentencePieceTrainer::SetProtoField("character_coverage", "0.5", &spec));
+ EXPECT_NEAR(spec.character_coverage(), 0.5, 0.001);
+ EXPECT_NOT_OK(
+ SentencePieceTrainer::SetProtoField("character_coverage", "UNK", &spec));
+
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("input", "foo,bar,buz", &spec));
+ EXPECT_EQ(3, spec.input_size());
+ EXPECT_EQ("foo", spec.input(0));
+ EXPECT_EQ("bar", spec.input(1));
+ EXPECT_EQ("buz", spec.input(2));
+
+ EXPECT_OK(SentencePieceTrainer::SetProtoField("model_type", "BPE", &spec));
+ EXPECT_NOT_OK(
+ SentencePieceTrainer::SetProtoField("model_type", "UNK", &spec));
+
+ // Nested message is not supported.
+ ModelProto proto;
+ EXPECT_NOT_OK(
+ SentencePieceTrainer::SetProtoField("trainer_spec", "UNK", &proto));
}
+
+TEST(SentencePieceTrainerTest, MergeSpecsFromArgs) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+ EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", nullptr, nullptr));
+
+ EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs("", &trainer_spec,
+ &normalizer_spec));
+
+ EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--unknown=BPE", &trainer_spec, &normalizer_spec));
+
+ EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--vocab_size=UNK", &trainer_spec, &normalizer_spec));
+
+ EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--model_type=UNK", &trainer_spec, &normalizer_spec));
+
+ EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--model_type=bpe", &trainer_spec, &normalizer_spec));
+
+ EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--split_by_whitespace", &trainer_spec, &normalizer_spec));
+
+ EXPECT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--normalization_rule_name=foo", &trainer_spec, &normalizer_spec));
+ EXPECT_EQ("foo", normalizer_spec.name());
+
+ EXPECT_NOT_OK(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--vocab_size=UNK", &trainer_spec, &normalizer_spec));
+}
+
} // namespace
} // namespace sentencepiece
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 6f9f917..761cc15 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -12,10 +12,140 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include "builder.h"
+#include "flags.h"
#include "sentencepiece_trainer.h"
+#include "util.h"
+
+using sentencepiece::NormalizerSpec;
+using sentencepiece::TrainerSpec;
+using sentencepiece::normalizer::Builder;
+
+namespace {
+static sentencepiece::TrainerSpec kDefaultTrainerSpec;
+static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
+} // namespace
+
+DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
+ "Input format. Supported format is `text` or `tsv`.");
+DEFINE_string(model_prefix, "", "output model prefix");
+DEFINE_string(model_type, "unigram",
+ "model algorithm: unigram, bpe, word or char");
+DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
+DEFINE_string(accept_language, "",
+ "comma-separated list of languages this model can accept");
+DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
+ "character coverage to determine the minimum symbols");
+DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
+ "maximum size of sentences the trainer loads");
+DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
+ "maximum size of sentences to make seed sentence piece");
+DEFINE_int32(training_sentence_size,
+ kDefaultTrainerSpec.training_sentence_size(),
+ "maximum size of sentences to train sentence pieces");
+DEFINE_int32(seed_sentencepiece_size,
+ kDefaultTrainerSpec.seed_sentencepiece_size(),
+ "the size of seed sentencepieces");
+DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
+ "Keeps top shrinking_factor pieces with respect to the loss");
+DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
+ "number of threads for training");
+DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
+ "number of EM sub-iterations");
+DEFINE_int32(max_sentencepiece_length,
+ kDefaultTrainerSpec.max_sentencepiece_length(),
+ "maximum length of sentence piece");
+DEFINE_bool(split_by_unicode_script,
+ kDefaultTrainerSpec.split_by_unicode_script(),
+ "use Unicode script to split sentence pieces");
+DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
+ "use a white space to split sentence pieces");
+DEFINE_string(control_symbols, "", "comma separated list of control symbols");
+DEFINE_string(user_defined_symbols, "",
+ "comma separated list of user defined symbols");
+DEFINE_string(normalization_rule_name, "nfkc",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
+DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
+DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
+ "Add dummy whitespace at the beginning of text");
+DEFINE_bool(remove_extra_whitespaces,
+ kDefaultNormalizerSpec.remove_extra_whitespaces(),
+ "Removes leading, trailing, and "
+ "duplicate internal whitespace");
+DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
+ "If set to false, --vocab_size is considered as a soft limit.");
+DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
+DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
+ "Override BOS (<s>) id. Set -1 to disable BOS.");
+DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
+ "Override EOS (</s>) id. Set -1 to disable EOS.");
+DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
+ "Override PAD (<pad>) id. Set -1 to disable PAD.");
int main(int argc, char *argv[]) {
- sentencepiece::SentencePieceTrainer::Train(argc, argv);
+ sentencepiece::flags::ParseCommandLineFlags(argc, argv);
+ sentencepiece::TrainerSpec trainer_spec;
+ sentencepiece::NormalizerSpec normalizer_spec;
+
+ CHECK_OR_HELP(input);
+ CHECK_OR_HELP(model_prefix);
+
+// Populates the value from flags to spec.
+#define SetTrainerSpecFromFlag(name) trainer_spec.set_##name(FLAGS_##name);
+
+#define SetNormalizerSpecFromFlag(name) \
+ normalizer_spec.set_##name(FLAGS_##name);
+
+#define SetRepeatedTrainerSpecFromFlag(name) \
+ if (!FLAGS_##name.empty()) { \
+ for (const auto v : \
+ sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
+ trainer_spec.add_##name(v); \
+ } \
+ }
+
+ SetTrainerSpecFromFlag(input_format);
+ SetTrainerSpecFromFlag(model_prefix);
+ SetTrainerSpecFromFlag(vocab_size);
+ SetTrainerSpecFromFlag(character_coverage);
+ SetTrainerSpecFromFlag(input_sentence_size);
+ SetTrainerSpecFromFlag(mining_sentence_size);
+ SetTrainerSpecFromFlag(training_sentence_size);
+ SetTrainerSpecFromFlag(seed_sentencepiece_size);
+ SetTrainerSpecFromFlag(shrinking_factor);
+ SetTrainerSpecFromFlag(num_threads);
+ SetTrainerSpecFromFlag(num_sub_iterations);
+ SetTrainerSpecFromFlag(max_sentencepiece_length);
+ SetTrainerSpecFromFlag(split_by_unicode_script);
+ SetTrainerSpecFromFlag(split_by_whitespace);
+ SetTrainerSpecFromFlag(hard_vocab_limit);
+ SetTrainerSpecFromFlag(unk_id);
+ SetTrainerSpecFromFlag(bos_id);
+ SetTrainerSpecFromFlag(eos_id);
+ SetTrainerSpecFromFlag(pad_id);
+ SetRepeatedTrainerSpecFromFlag(input);
+ SetRepeatedTrainerSpecFromFlag(accept_language);
+ SetRepeatedTrainerSpecFromFlag(control_symbols);
+ SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
+
+ normalizer_spec.set_name(FLAGS_normalization_rule_name);
+ SetNormalizerSpecFromFlag(normalization_rule_tsv);
+ SetNormalizerSpecFromFlag(add_dummy_prefix);
+ SetNormalizerSpecFromFlag(remove_extra_whitespaces);
+
+ const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
+ {"unigram", TrainerSpec::UNIGRAM},
+ {"bpe", TrainerSpec::BPE},
+ {"word", TrainerSpec::WORD},
+ {"char", TrainerSpec::CHAR}};
+
+ trainer_spec.set_model_type(
+ sentencepiece::port::FindOrDie(kModelTypeMap, FLAGS_model_type));
+
+ CHECK_OK(sentencepiece::SentencePieceTrainer::Train(trainer_spec,
+ normalizer_spec));
return 0;
}
diff --git a/src/util.h b/src/util.h
index b882b57..cf30882 100644
--- a/src/util.h
+++ b/src/util.h
@@ -37,6 +37,52 @@ std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
// String utilities
namespace string_util {
+inline std::string ToLower(StringPiece arg) {
+ std::string lower_value = arg.ToString();
+ std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
+ ::tolower);
+ return lower_value;
+}
+
+inline std::string ToUpper(StringPiece arg) {
+ std::string upper_value = arg.ToString();
+ std::transform(upper_value.begin(), upper_value.end(), upper_value.begin(),
+ ::toupper);
+ return upper_value;
+}
+
+template <typename Target>
+inline bool lexical_cast(StringPiece arg, Target *result) {
+ std::stringstream ss;
+ return (ss << arg.data() && ss >> *result);
+}
+
+template <>
+inline bool lexical_cast(StringPiece arg, bool *result) {
+ const char *kTrue[] = {"1", "t", "true", "y", "yes"};
+ const char *kFalse[] = {"0", "f", "false", "n", "no"};
+ std::string lower_value = arg.ToString();
+ std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
+ ::tolower);
+ for (size_t i = 0; i < 5; ++i) {
+ if (lower_value == kTrue[i]) {
+ *result = true;
+ return true;
+ } else if (lower_value == kFalse[i]) {
+ *result = false;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+template <>
+inline bool lexical_cast(StringPiece arg, std::string *result) {
+ *result = arg.ToString();
+ return true;
+}
+
std::vector<std::string> Split(const std::string &str,
const std::string &delim);
diff --git a/src/util_test.cc b/src/util_test.cc
index 7720acb..224ea50 100644
--- a/src/util_test.cc
+++ b/src/util_test.cc
@@ -18,6 +18,31 @@
namespace sentencepiece {
+TEST(UtilTest, LexicalCastTest) {
+ bool b = false;
+ EXPECT_TRUE(string_util::lexical_cast<bool>("true", &b));
+ EXPECT_TRUE(b);
+ EXPECT_TRUE(string_util::lexical_cast<bool>("false", &b));
+ EXPECT_FALSE(b);
+ EXPECT_FALSE(string_util::lexical_cast<bool>("UNK", &b));
+
+ int32 n = 0;
+ EXPECT_TRUE(string_util::lexical_cast<int32>("123", &n));
+ EXPECT_EQ(123, n);
+ EXPECT_TRUE(string_util::lexical_cast<int32>("-123", &n));
+ EXPECT_EQ(-123, n);
+ EXPECT_FALSE(string_util::lexical_cast<int32>("UNK", &n));
+
+ double d = 0.0;
+ EXPECT_TRUE(string_util::lexical_cast<double>("123.4", &d));
+ EXPECT_NEAR(123.4, d, 0.001);
+ EXPECT_FALSE(string_util::lexical_cast<double>("UNK", &d));
+
+ std::string s;
+ EXPECT_TRUE(string_util::lexical_cast<std::string>("123.4", &s));
+ EXPECT_EQ("123.4", s);
+}
+
TEST(UtilTest, CheckNotNullTest) {
int a = 0;
CHECK_NOTNULL(&a);