diff options
author | Taku Kudo <taku@google.com> | 2018-04-30 18:00:35 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-04-30 18:00:35 +0300 |
commit | f228e556b484015cdb29c467416dd95ca38360cd (patch) | |
tree | 5b2b31e2a4c81d9f701610ecdfcaa4309d52b5d1 /src/sentencepiece_trainer.cc | |
parent | 36a3b35e17f995c5fb77c0397122b0d6d41fef1b (diff) |
Reimplement Trainer with Proto reflection
Diffstat (limited to 'src/sentencepiece_trainer.cc')
-rw-r--r-- | src/sentencepiece_trainer.cc | 314 |
1 files changed, 148 insertions, 166 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index 5170f3a..bfc2f8e 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -13,7 +13,6 @@ // limitations under the License.! #include "sentencepiece_trainer.h" -#include <mutex> #include <string> #include "builder.h" @@ -25,196 +24,179 @@ #include "trainer_factory.h" #include "util.h" -namespace { -static const sentencepiece::TrainerSpec kDefaultTrainerSpec; -static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec; -} // namespace - -DEFINE_string(input, "", "comma separated list of input sentences"); -DEFINE_string(input_format, kDefaultTrainerSpec.input_format(), - "Input format. Supported format is `text` or `tsv`."); -DEFINE_string(model_prefix, "", "output model prefix"); -DEFINE_string(model_type, "unigram", - "model algorithm: unigram, bpe, word or char"); -DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size"); -DEFINE_string(accept_language, "", - "comma-separated list of languages this model can accept"); -DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(), - "character coverage to determine the minimum symbols"); -DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(), - "maximum size of sentences the trainer loads"); -DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(), - "maximum size of sentences to make seed sentence piece"); -DEFINE_int32(training_sentence_size, - kDefaultTrainerSpec.training_sentence_size(), - "maximum size of sentences to train sentence pieces"); -DEFINE_int32(seed_sentencepiece_size, - kDefaultTrainerSpec.seed_sentencepiece_size(), - "the size of seed sentencepieces"); -DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(), - "Keeps top shrinking_factor pieces with respect to the loss"); -DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(), - "number of threads for training"); -DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(), - "number of EM sub-iterations"); -DEFINE_int32(max_sentencepiece_length, - kDefaultTrainerSpec.max_sentencepiece_length(), - "maximum length of sentence piece"); -DEFINE_bool(split_by_unicode_script, - kDefaultTrainerSpec.split_by_unicode_script(), - "use Unicode script to split sentence pieces"); -DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(), - "use a white space to split sentence pieces"); -DEFINE_string(control_symbols, "", "comma separated list of control symbols"); -DEFINE_string(user_defined_symbols, "", - "comma separated list of user defined symbols"); -DEFINE_string(normalization_rule_name, "nfkc", - "Normalization rule name. " - "Choose from nfkc or identity"); -DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. "); -DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(), - "Add dummy whitespace at the beginning of text"); -DEFINE_bool(remove_extra_whitespaces, - kDefaultNormalizerSpec.remove_extra_whitespaces(), - "Removes leading, trailing, and " - "duplicate internal whitespace"); -DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(), - "If set to false, --vocab_size is considered as a soft limit."); -DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), - "Override UNK (<unk>) id."); -DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(), - "Override BOS (<s>) id. Set -1 to disable BOS."); -DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(), - "Override EOS (</s>) id. Set -1 to disable EOS."); -DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(), - "Override PAD (<pad>) id. Set -1 to disable PAD."); - -using sentencepiece::NormalizerSpec; -using sentencepiece::TrainerSpec; -using sentencepiece::normalizer::Builder; - namespace sentencepiece { namespace { static constexpr char kDefaultNormalizerName[] = "nfkc"; +} // namespace + +// static +util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) { + NormalizerSpec normalizer_spec; + normalizer_spec.set_name(kDefaultNormalizerName); + Train(trainer_spec, normalizer_spec); + return util::OkStatus(); +} -NormalizerSpec MakeNormalizerSpecFromFlags() { - if (!FLAGS_normalization_rule_tsv.empty()) { - const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile( - FLAGS_normalization_rule_tsv); - sentencepiece::NormalizerSpec spec; - spec.set_name("user_defined"); - spec.set_precompiled_charsmap( - sentencepiece::normalizer::Builder::CompileCharsMap(chars_map)); - return spec; +// static +util::Status SentencePieceTrainer::Train( + const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) { + auto copied_normalizer_spec = normalizer_spec; + + if (!copied_normalizer_spec.normalization_rule_tsv().empty()) { + if (!copied_normalizer_spec.precompiled_charsmap().empty()) { + return util::InternalError("precompiled_charsmap is already defined."); + } + + const auto chars_map = normalizer::Builder::BuildMapFromFile( + copied_normalizer_spec.normalization_rule_tsv()); + copied_normalizer_spec.set_precompiled_charsmap( + normalizer::Builder::CompileCharsMap(chars_map)); + copied_normalizer_spec.set_name("user_defined"); + } else { + if (copied_normalizer_spec.name().empty()) { + copied_normalizer_spec.set_name(kDefaultNormalizerName); + } + + if (copied_normalizer_spec.precompiled_charsmap().empty()) { + *(copied_normalizer_spec.mutable_precompiled_charsmap()) = + normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name()) + .precompiled_charsmap(); + } } - return sentencepiece::normalizer::Builder::GetNormalizerSpec( - FLAGS_normalization_rule_name); -} + auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec); + trainer->Train(); -TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) { - const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = { - {"unigram", TrainerSpec::UNIGRAM}, - {"bpe", TrainerSpec::BPE}, - {"word", TrainerSpec::WORD}, - {"char", TrainerSpec::CHAR}}; - return port::FindOrDie(kModelTypeMap, type); + return util::OkStatus(); } -// Populates the value from flags to spec. -#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name); +// static +util::Status SentencePieceTrainer::SetProtoField( + const std::string &field_name, const std::string &value, + google::protobuf::Message *message) { + const auto *descriptor = message->GetDescriptor(); + const auto *reflection = message->GetReflection(); + + if (descriptor == nullptr || reflection == nullptr) { + return util::InternalError("Reflection is not supported."); + } + + const auto *field = descriptor->FindFieldByName(std::string(field_name)); -#define SetNormalizerSpecFromFlag(name) \ - normalizer_spec->set_##name(FLAGS_##name); + if (field == nullptr) { + return util::NotFoundError(std::string("Unknown field name \"") + + field_name + "\" in " + + descriptor->DebugString()); + } -#define SetRepeatedTrainerSpecFromFlag(name) \ - if (!FLAGS_##name.empty()) { \ - for (const auto v : \ - sentencepiece::string_util::Split(FLAGS_##name, ",")) { \ - trainer_spec->add_##name(v); \ - } \ + std::vector<std::string> values = {value}; + if (field->is_repeated()) values = string_util::Split(value, ","); + +#define SET_FIELD(METHOD_TYPE, v) \ + if (field->is_repeated()) \ + reflection->Add##METHOD_TYPE(message, field, v); \ + else \ + reflection->Set##METHOD_TYPE(message, field, v); + +#define DEFINE_SET_FIELD(PROTO_TYPE, CPP_TYPE, FUNC_PREFIX, METHOD_TYPE, \ + EMPTY) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##PROTO_TYPE: { \ + CPP_TYPE v; \ + if (!string_util::lexical_cast(value.empty() ? EMPTY : value, &v)) \ + return util::InvalidArgumentError(std::string("Cannot parse \"") + \ + value + "\" as \"" + \ + field->type_name() + "\"."); \ + SET_FIELD(METHOD_TYPE, v); \ + break; \ } -void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec, - NormalizerSpec *normalizer_spec) { - CHECK_NOTNULL(trainer_spec); - CHECK_NOTNULL(normalizer_spec); - - SetTrainerSpecFromFlag(input_format); - SetTrainerSpecFromFlag(model_prefix); - SetTrainerSpecFromFlag(vocab_size); - SetTrainerSpecFromFlag(character_coverage); - SetTrainerSpecFromFlag(input_sentence_size); - SetTrainerSpecFromFlag(mining_sentence_size); - SetTrainerSpecFromFlag(training_sentence_size); - SetTrainerSpecFromFlag(seed_sentencepiece_size); - SetTrainerSpecFromFlag(shrinking_factor); - SetTrainerSpecFromFlag(num_threads); - SetTrainerSpecFromFlag(num_sub_iterations); - SetTrainerSpecFromFlag(max_sentencepiece_length); - SetTrainerSpecFromFlag(split_by_unicode_script); - SetTrainerSpecFromFlag(split_by_whitespace); - SetTrainerSpecFromFlag(hard_vocab_limit); - SetTrainerSpecFromFlag(unk_id); - SetTrainerSpecFromFlag(bos_id); - SetTrainerSpecFromFlag(eos_id); - SetTrainerSpecFromFlag(pad_id); - SetRepeatedTrainerSpecFromFlag(accept_language); - SetRepeatedTrainerSpecFromFlag(control_symbols); - SetRepeatedTrainerSpecFromFlag(user_defined_symbols); - - *normalizer_spec = MakeNormalizerSpecFromFlags(); - SetNormalizerSpecFromFlag(add_dummy_prefix); - SetNormalizerSpecFromFlag(remove_extra_whitespaces); - - for (const auto &filename : - sentencepiece::string_util::Split(FLAGS_input, ",")) { - trainer_spec->add_input(filename); + for (const auto &value : values) { + switch (field->cpp_type()) { + DEFINE_SET_FIELD(INT32, int32, i, Int32, ""); + DEFINE_SET_FIELD(INT64, int64, i, Int64, ""); + DEFINE_SET_FIELD(UINT32, uint32, i, UInt32, ""); + DEFINE_SET_FIELD(UINT64, uint64, i, UInt64, ""); + DEFINE_SET_FIELD(DOUBLE, double, d, Double, ""); + DEFINE_SET_FIELD(FLOAT, float, f, Float, ""); + DEFINE_SET_FIELD(BOOL, bool, b, Bool, "true"); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + SET_FIELD(String, value); + break; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + const auto *enum_value = + field->enum_type()->FindValueByName(string_util::ToUpper(value)); + if (enum_value == nullptr) + return util::InvalidArgumentError( + std::string("Unknown enumeration value of \"") + value + + "\" for field \"" + field->name() + "\"."); + SET_FIELD(Enum, enum_value); + break; + } + default: + return util::UnimplementedError(std::string("Proto type \"") + + field->cpp_type_name() + + "\" is not supported."); + break; + } } - trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type)); + return util::OkStatus(); } -} // namespace // static -void SentencePieceTrainer::Train(int argc, char **argv) { - TrainerSpec trainer_spec; - NormalizerSpec normalizer_spec; - { - static std::mutex flags_mutex; - std::lock_guard<std::mutex> lock(flags_mutex); - sentencepiece::flags::ParseCommandLineFlags(argc, argv); - CHECK_OR_HELP(input); - CHECK_OR_HELP(model_prefix); - MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec); +util::Status SentencePieceTrainer::MergeSpecsFromArgs( + const std::string &args, TrainerSpec *trainer_spec, + NormalizerSpec *normalizer_spec) { + if (trainer_spec == nullptr || normalizer_spec == nullptr) { + return util::InternalError( + "`trainer_spec` and `normalizer_spec` must not be null."); } - SentencePieceTrainer::Train(trainer_spec, normalizer_spec); -} + if (args.empty()) return util::OkStatus(); + + for (auto arg : string_util::SplitPiece(args, " ")) { + arg.Consume("--"); + std::string key, value; + auto pos = arg.find("="); + if (pos == StringPiece::npos) { + key = arg.ToString(); + } else { + key = arg.substr(0, pos).ToString(); + value = arg.substr(pos + 1).ToString(); + } + + // Exception. + if (key == "normalization_rule_name") { + normalizer_spec->set_name(value); + continue; + } + + const auto status_train = SetProtoField(key, value, trainer_spec); + if (status_train.ok()) continue; + if (!util::IsNotFound(status_train)) return status_train; + + const auto status_norm = SetProtoField(key, value, normalizer_spec); + if (status_norm.ok()) continue; + if (!util::IsNotFound(status_norm)) return status_norm; + + // Not found both in trainer_spec and normalizer_spec. + if (util::IsNotFound(status_train) && util::IsNotFound(status_norm)) { + return status_train; + } + } -// static -void SentencePieceTrainer::Train(const std::string &arg) { - const std::vector<std::string> args = - sentencepiece::string_util::Split(arg, " "); - std::vector<char *> cargs(args.size() + 1); - cargs[0] = const_cast<char *>(""); - for (size_t i = 0; i < args.size(); ++i) - cargs[i + 1] = const_cast<char *>(args[i].data()); - SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]); + return util::OkStatus(); } // static -void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) { - SentencePieceTrainer::Train( - trainer_spec, - normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName)); -} +util::Status SentencePieceTrainer::Train(const std::string &args) { + TrainerSpec trainer_spec; + NormalizerSpec normalizer_spec; + normalizer_spec.set_name(kDefaultNormalizerName); -// static -void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, - const NormalizerSpec &normalizer_spec) { - auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec); - trainer->Train(); + CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec)); + + return Train(trainer_spec, normalizer_spec); } } // namespace sentencepiece |