Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com>, 2018-04-30 18:00:35 +0300
committer: Taku Kudo <taku@google.com>, 2018-04-30 18:00:35 +0300
commit: f228e556b484015cdb29c467416dd95ca38360cd (patch)
tree: 5b2b31e2a4c81d9f701610ecdfcaa4309d52b5d1 /src/sentencepiece_trainer.cc
parent: 36a3b35e17f995c5fb77c0397122b0d6d41fef1b (diff)
Reimplement Trainer with Proto reflection
Diffstat (limited to 'src/sentencepiece_trainer.cc')
-rw-r--r-- src/sentencepiece_trainer.cc 314
1 files changed, 148 insertions, 166 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 5170f3a..bfc2f8e 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -13,7 +13,6 @@
// limitations under the License.!
#include "sentencepiece_trainer.h"
-#include <mutex>
#include <string>
#include "builder.h"
@@ -25,196 +24,179 @@
#include "trainer_factory.h"
#include "util.h"
-namespace {
-static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
-static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
-} // namespace
-
-DEFINE_string(input, "", "comma separated list of input sentences");
-DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
- "Input format. Supported format is `text` or `tsv`.");
-DEFINE_string(model_prefix, "", "output model prefix");
-DEFINE_string(model_type, "unigram",
- "model algorithm: unigram, bpe, word or char");
-DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
-DEFINE_string(accept_language, "",
- "comma-separated list of languages this model can accept");
-DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
- "character coverage to determine the minimum symbols");
-DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
- "maximum size of sentences the trainer loads");
-DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
- "maximum size of sentences to make seed sentence piece");
-DEFINE_int32(training_sentence_size,
- kDefaultTrainerSpec.training_sentence_size(),
- "maximum size of sentences to train sentence pieces");
-DEFINE_int32(seed_sentencepiece_size,
- kDefaultTrainerSpec.seed_sentencepiece_size(),
- "the size of seed sentencepieces");
-DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
- "Keeps top shrinking_factor pieces with respect to the loss");
-DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
- "number of threads for training");
-DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
- "number of EM sub-iterations");
-DEFINE_int32(max_sentencepiece_length,
- kDefaultTrainerSpec.max_sentencepiece_length(),
- "maximum length of sentence piece");
-DEFINE_bool(split_by_unicode_script,
- kDefaultTrainerSpec.split_by_unicode_script(),
- "use Unicode script to split sentence pieces");
-DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
- "use a white space to split sentence pieces");
-DEFINE_string(control_symbols, "", "comma separated list of control symbols");
-DEFINE_string(user_defined_symbols, "",
- "comma separated list of user defined symbols");
-DEFINE_string(normalization_rule_name, "nfkc",
- "Normalization rule name. "
- "Choose from nfkc or identity");
-DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
-DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
- "Add dummy whitespace at the beginning of text");
-DEFINE_bool(remove_extra_whitespaces,
- kDefaultNormalizerSpec.remove_extra_whitespaces(),
- "Removes leading, trailing, and "
- "duplicate internal whitespace");
-DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
- "If set to false, --vocab_size is considered as a soft limit.");
-DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
- "Override UNK (<unk>) id.");
-DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
- "Override BOS (<s>) id. Set -1 to disable BOS.");
-DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
- "Override EOS (</s>) id. Set -1 to disable EOS.");
-DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
- "Override PAD (<pad>) id. Set -1 to disable PAD.");
-
-using sentencepiece::NormalizerSpec;
-using sentencepiece::TrainerSpec;
-using sentencepiece::normalizer::Builder;
-
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nfkc";
+} // namespace
+
+// static
+util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name(kDefaultNormalizerName);
+ Train(trainer_spec, normalizer_spec);
+ return util::OkStatus();
+}
-NormalizerSpec MakeNormalizerSpecFromFlags() {
- if (!FLAGS_normalization_rule_tsv.empty()) {
- const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile(
- FLAGS_normalization_rule_tsv);
- sentencepiece::NormalizerSpec spec;
- spec.set_name("user_defined");
- spec.set_precompiled_charsmap(
- sentencepiece::normalizer::Builder::CompileCharsMap(chars_map));
- return spec;
+// static
+util::Status SentencePieceTrainer::Train(
+ const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
+ auto copied_normalizer_spec = normalizer_spec;
+
+ if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
+ if (!copied_normalizer_spec.precompiled_charsmap().empty()) {
+ return util::InternalError("precompiled_charsmap is already defined.");
+ }
+
+ const auto chars_map = normalizer::Builder::BuildMapFromFile(
+ copied_normalizer_spec.normalization_rule_tsv());
+ copied_normalizer_spec.set_precompiled_charsmap(
+ normalizer::Builder::CompileCharsMap(chars_map));
+ copied_normalizer_spec.set_name("user_defined");
+ } else {
+ if (copied_normalizer_spec.name().empty()) {
+ copied_normalizer_spec.set_name(kDefaultNormalizerName);
+ }
+
+ if (copied_normalizer_spec.precompiled_charsmap().empty()) {
+ *(copied_normalizer_spec.mutable_precompiled_charsmap()) =
+ normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name())
+ .precompiled_charsmap();
+ }
}
- return sentencepiece::normalizer::Builder::GetNormalizerSpec(
- FLAGS_normalization_rule_name);
-}
+ auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
+ trainer->Train();
-TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) {
- const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
- {"unigram", TrainerSpec::UNIGRAM},
- {"bpe", TrainerSpec::BPE},
- {"word", TrainerSpec::WORD},
- {"char", TrainerSpec::CHAR}};
- return port::FindOrDie(kModelTypeMap, type);
+ return util::OkStatus();
}
-// Populates the value from flags to spec.
-#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name);
+// static
+util::Status SentencePieceTrainer::SetProtoField(
+ const std::string &field_name, const std::string &value,
+ google::protobuf::Message *message) {
+ const auto *descriptor = message->GetDescriptor();
+ const auto *reflection = message->GetReflection();
+
+ if (descriptor == nullptr || reflection == nullptr) {
+ return util::InternalError("Reflection is not supported.");
+ }
+
+ const auto *field = descriptor->FindFieldByName(std::string(field_name));
-#define SetNormalizerSpecFromFlag(name) \
- normalizer_spec->set_##name(FLAGS_##name);
+ if (field == nullptr) {
+ return util::NotFoundError(std::string("Unknown field name \"") +
+ field_name + "\" in " +
+ descriptor->DebugString());
+ }
-#define SetRepeatedTrainerSpecFromFlag(name) \
- if (!FLAGS_##name.empty()) { \
- for (const auto v : \
- sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
- trainer_spec->add_##name(v); \
- } \
+ std::vector<std::string> values = {value};
+ if (field->is_repeated()) values = string_util::Split(value, ",");
+
+#define SET_FIELD(METHOD_TYPE, v) \
+ if (field->is_repeated()) \
+ reflection->Add##METHOD_TYPE(message, field, v); \
+ else \
+ reflection->Set##METHOD_TYPE(message, field, v);
+
+#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \
+ EMPTY) \
+ case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \
+ CPP_TYPE v; \
+ if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \
+ return util::InvalidArgumentError(std::string("Cannot parse \"") + \
+ value + "\" as \"" + \
+ field->type_name() + "\"."); \
+ SET_FIELD(METHOD_TYPE, v); \
+ break; \
}
-void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
- NormalizerSpec *normalizer_spec) {
- CHECK_NOTNULL(trainer_spec);
- CHECK_NOTNULL(normalizer_spec);
-
- SetTrainerSpecFromFlag(input_format);
- SetTrainerSpecFromFlag(model_prefix);
- SetTrainerSpecFromFlag(vocab_size);
- SetTrainerSpecFromFlag(character_coverage);
- SetTrainerSpecFromFlag(input_sentence_size);
- SetTrainerSpecFromFlag(mining_sentence_size);
- SetTrainerSpecFromFlag(training_sentence_size);
- SetTrainerSpecFromFlag(seed_sentencepiece_size);
- SetTrainerSpecFromFlag(shrinking_factor);
- SetTrainerSpecFromFlag(num_threads);
- SetTrainerSpecFromFlag(num_sub_iterations);
- SetTrainerSpecFromFlag(max_sentencepiece_length);
- SetTrainerSpecFromFlag(split_by_unicode_script);
- SetTrainerSpecFromFlag(split_by_whitespace);
- SetTrainerSpecFromFlag(hard_vocab_limit);
- SetTrainerSpecFromFlag(unk_id);
- SetTrainerSpecFromFlag(bos_id);
- SetTrainerSpecFromFlag(eos_id);
- SetTrainerSpecFromFlag(pad_id);
- SetRepeatedTrainerSpecFromFlag(accept_language);
- SetRepeatedTrainerSpecFromFlag(control_symbols);
- SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
-
- *normalizer_spec = MakeNormalizerSpecFromFlags();
- SetNormalizerSpecFromFlag(add_dummy_prefix);
- SetNormalizerSpecFromFlag(remove_extra_whitespaces);
-
- for (const auto &filename :
- sentencepiece::string_util::Split(FLAGS_input, ",")) {
- trainer_spec->add_input(filename);
+ for (const auto &value : values) {
+ switch (field->cpp_type()) {
+ DEFINE_SET_FIELD(INT32, int32, i, Int32, "");
+ DEFINE_SET_FIELD(INT64, int64, i, Int64, "");
+ DEFINE_SET_FIELD(UINT32, uint32, i, UInt32, "");
+ DEFINE_SET_FIELD(UINT64, uint64, i, UInt64, "");
+ DEFINE_SET_FIELD(DOUBLE, double, d, Double, "");
+ DEFINE_SET_FIELD(FLOAT, float, f, Float, "");
+ DEFINE_SET_FIELD(BOOL, bool, b, Bool, "true");
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
+ SET_FIELD(String, value);
+ break;
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ const auto *enum_value =
+ field->enum_type()->FindValueByName(string_util::ToUpper(value));
+ if (enum_value == nullptr)
+ return util::InvalidArgumentError(
+ std::string("Unknown enumeration value of \"") + value +
+ "\" for field \"" + field->name() + "\".");
+ SET_FIELD(Enum, enum_value);
+ break;
+ }
+ default:
+ return util::UnimplementedError(std::string("Proto type \"") +
+ field->cpp_type_name() +
+ "\" is not supported.");
+ break;
+ }
}
- trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type));
+ return util::OkStatus();
}
-} // namespace
// static
-void SentencePieceTrainer::Train(int argc, char **argv) {
- TrainerSpec trainer_spec;
- NormalizerSpec normalizer_spec;
- {
- static std::mutex flags_mutex;
- std::lock_guard<std::mutex> lock(flags_mutex);
- sentencepiece::flags::ParseCommandLineFlags(argc, argv);
- CHECK_OR_HELP(input);
- CHECK_OR_HELP(model_prefix);
- MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec);
+util::Status SentencePieceTrainer::MergeSpecsFromArgs(
+ const std::string &args, TrainerSpec *trainer_spec,
+ NormalizerSpec *normalizer_spec) {
+ if (trainer_spec == nullptr || normalizer_spec == nullptr) {
+ return util::InternalError(
+ "`trainer_spec` and `normalizer_spec` must not be null.");
}
- SentencePieceTrainer::Train(trainer_spec, normalizer_spec);
-}
+ if (args.empty()) return util::OkStatus();
+
+ for (auto arg : string_util::SplitPiece(args, " ")) {
+ arg.Consume("--");
+ std::string key, value;
+ auto pos = arg.find("=");
+ if (pos == StringPiece::npos) {
+ key = arg.ToString();
+ } else {
+ key = arg.substr(0, pos).ToString();
+ value = arg.substr(pos + 1).ToString();
+ }
+
+ // Exception.
+ if (key == "normalization_rule_name") {
+ normalizer_spec->set_name(value);
+ continue;
+ }
+
+ const auto status_train = SetProtoField(key, value, trainer_spec);
+ if (status_train.ok()) continue;
+ if (!util::IsNotFound(status_train)) return status_train;
+
+ const auto status_norm = SetProtoField(key, value, normalizer_spec);
+ if (status_norm.ok()) continue;
+ if (!util::IsNotFound(status_norm)) return status_norm;
+
+ // Not found both in trainer_spec and normalizer_spec.
+ if (util::IsNotFound(status_train) && util::IsNotFound(status_norm)) {
+ return status_train;
+ }
+ }
-// static
-void SentencePieceTrainer::Train(const std::string &arg) {
- const std::vector<std::string> args =
- sentencepiece::string_util::Split(arg, " ");
- std::vector<char *> cargs(args.size() + 1);
- cargs[0] = const_cast<char *>("");
- for (size_t i = 0; i < args.size(); ++i)
- cargs[i + 1] = const_cast<char *>(args[i].data());
- SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]);
+ return util::OkStatus();
}
// static
-void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
- SentencePieceTrainer::Train(
- trainer_spec,
- normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName));
-}
+util::Status SentencePieceTrainer::Train(const std::string &args) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name(kDefaultNormalizerName);
-// static
-void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec,
- const NormalizerSpec &normalizer_spec) {
- auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec);
- trainer->Train();
+ CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
+
+ return Train(trainer_spec, normalizer_spec);
}
} // namespace sentencepiece