// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #include #include "filesystem.h" #include "init.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_trainer.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/ascii.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" #include "util.h" using sentencepiece::NormalizerSpec; using sentencepiece::TrainerSpec; namespace { static sentencepiece::TrainerSpec kDefaultTrainerSpec; static sentencepiece::NormalizerSpec kDefaultNormalizerSpec; } // namespace ABSL_FLAG(std::string, input, "", "comma separated list of input sentences"); ABSL_FLAG(std::string, input_format, kDefaultTrainerSpec.input_format(), "Input format. Supported format is `text` or `tsv`."); ABSL_FLAG(std::string, model_prefix, "", "output model prefix"); ABSL_FLAG(std::string, model_type, "unigram", "model algorithm: unigram, bpe, word or char"); ABSL_FLAG(int32, vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size"); ABSL_FLAG(std::string, accept_language, "", "comma-separated list of languages this model can accept"); ABSL_FLAG(int32, self_test_sample_size, kDefaultTrainerSpec.self_test_sample_size(), "the size of self test samples"); ABSL_FLAG(double, character_coverage, kDefaultTrainerSpec.character_coverage(), "character coverage to determine the minimum symbols"); ABSL_FLAG(int32, input_sentence_size, kDefaultTrainerSpec.input_sentence_size(), "maximum size of sentences the trainer loads"); ABSL_FLAG(bool, shuffle_input_sentence, kDefaultTrainerSpec.shuffle_input_sentence(), "Randomly sample input sentences in advance. Valid when " "--input_sentence_size > 0"); ABSL_FLAG(int32, seed_sentencepiece_size, kDefaultTrainerSpec.seed_sentencepiece_size(), "the size of seed sentencepieces"); ABSL_FLAG(double, shrinking_factor, kDefaultTrainerSpec.shrinking_factor(), "Keeps top shrinking_factor pieces with respect to the loss"); ABSL_FLAG(int32, num_threads, kDefaultTrainerSpec.num_threads(), "number of threads for training"); ABSL_FLAG(int32, num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(), "number of EM sub-iterations"); ABSL_FLAG(int32, max_sentencepiece_length, kDefaultTrainerSpec.max_sentencepiece_length(), "maximum length of sentence piece"); ABSL_FLAG(int32, max_sentence_length, kDefaultTrainerSpec.max_sentence_length(), "maximum length of sentence in byte"); ABSL_FLAG(bool, split_by_unicode_script, kDefaultTrainerSpec.split_by_unicode_script(), "use Unicode script to split sentence pieces"); ABSL_FLAG(bool, split_by_number, kDefaultTrainerSpec.split_by_number(), "split tokens by numbers (0-9)"); ABSL_FLAG(bool, split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(), "use a white space to split sentence pieces"); ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(), "split all digits (0-9) into separate pieces"); ABSL_FLAG(bool, treat_whitespace_as_suffix, kDefaultTrainerSpec.treat_whitespace_as_suffix(), "treat whitespace marker as suffix instead of prefix."); ABSL_FLAG(std::string, control_symbols, "", "comma separated list of control symbols"); ABSL_FLAG(std::string, control_symbols_file, "", "load control_symbols from file."); ABSL_FLAG(std::string, user_defined_symbols, "", "comma separated list of user defined symbols"); ABSL_FLAG(std::string, user_defined_symbols_file, "", "load user_defined_symbols from file."); ABSL_FLAG(std::string, required_chars, "", "UTF8 characters in this flag are always used in the character " "set regardless of --character_coverage"); ABSL_FLAG(std::string, required_chars_file, "", "load required_chars from file."); ABSL_FLAG(bool, byte_fallback, kDefaultTrainerSpec.byte_fallback(), "decompose unknown pieces into UTF-8 byte pieces"); ABSL_FLAG(bool, vocabulary_output_piece_score, kDefaultTrainerSpec.vocabulary_output_piece_score(), "Define score in vocab file"); ABSL_FLAG(std::string, normalization_rule_name, "nmt_nfkc", "Normalization rule name. " "Choose from nfkc or identity"); ABSL_FLAG(std::string, normalization_rule_tsv, "", "Normalization rule TSV file. "); ABSL_FLAG(std::string, denormalization_rule_tsv, "", "Denormalization rule TSV file."); ABSL_FLAG(bool, add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(), "Add dummy whitespace at the beginning of text"); ABSL_FLAG(bool, remove_extra_whitespaces, kDefaultNormalizerSpec.remove_extra_whitespaces(), "Removes leading, trailing, and " "duplicate internal whitespace"); ABSL_FLAG(bool, encode_unicode_case, kDefaultNormalizerSpec.encode_case(), "Handles unicode case"); ABSL_FLAG(bool, hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(), "If set to false, --vocab_size is considered as a soft limit."); ABSL_FLAG(bool, use_all_vocab, kDefaultTrainerSpec.use_all_vocab(), "If set to true, use all tokens as vocab. " "Valid for word/char models."); ABSL_FLAG(int32, unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK () id."); ABSL_FLAG(int32, bos_id, kDefaultTrainerSpec.bos_id(), "Override BOS () id. Set -1 to disable BOS."); ABSL_FLAG(int32, eos_id, kDefaultTrainerSpec.eos_id(), "Override EOS () id. Set -1 to disable EOS."); ABSL_FLAG(int32, pad_id, kDefaultTrainerSpec.pad_id(), "Override PAD () id. Set -1 to disable PAD."); ABSL_FLAG(std::string, unk_piece, kDefaultTrainerSpec.unk_piece(), "Override UNK () piece."); ABSL_FLAG(std::string, bos_piece, kDefaultTrainerSpec.bos_piece(), "Override BOS () piece."); ABSL_FLAG(std::string, eos_piece, kDefaultTrainerSpec.eos_piece(), "Override EOS () piece."); ABSL_FLAG(std::string, pad_piece, kDefaultTrainerSpec.pad_piece(), "Override PAD () piece."); ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(), "Dummy surface string for . In decoding is decoded to " "`unk_surface`."); ABSL_FLAG(bool, train_extremely_large_corpus, kDefaultTrainerSpec.train_extremely_large_corpus(), "Increase bit depth for unigram tokenization."); ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); sentencepiece::TrainerSpec trainer_spec; sentencepiece::NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); if (absl::GetFlag(FLAGS_random_seed) != -1) sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); auto load_lines = [](absl::string_view filename) { std::vector lines; auto input = sentencepiece::filesystem::NewReadableFile(filename); CHECK_OK(input->status()); std::string line; while (input->ReadLine(&line)) lines.emplace_back(line); return lines; }; // Populates the value from flags to spec. #define SetTrainerSpecFromFlag(name) \ trainer_spec.set_##name(absl::GetFlag(FLAGS_##name)); #define SetNormalizerSpecFromFlag(name) \ normalizer_spec.set_##name(absl::GetFlag(FLAGS_##name)); #define SetTrainerSpecFromFile(name) \ if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \ const auto lines = load_lines(absl::GetFlag(FLAGS_##name##_file)); \ trainer_spec.set_##name(absl::StrJoin(lines, "")); \ } #define SetRepeatedTrainerSpecFromFlag(name) \ if (!absl::GetFlag(FLAGS_##name).empty()) { \ for (const auto &v : \ sentencepiece::util::StrSplitAsCSV(absl::GetFlag(FLAGS_##name))) { \ trainer_spec.add_##name(v); \ } \ } #define SetRepeatedTrainerSpecFromFile(name) \ if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \ for (const auto &v : load_lines(absl::GetFlag(FLAGS_##name##_file))) { \ trainer_spec.add_##name(v); \ } \ } SetRepeatedTrainerSpecFromFlag(input); SetTrainerSpecFromFlag(input_format); SetTrainerSpecFromFlag(model_prefix); SetTrainerSpecFromFlag(vocab_size); SetTrainerSpecFromFlag(self_test_sample_size); SetTrainerSpecFromFlag(character_coverage); SetTrainerSpecFromFlag(input_sentence_size); SetTrainerSpecFromFlag(shuffle_input_sentence); SetTrainerSpecFromFlag(seed_sentencepiece_size); SetTrainerSpecFromFlag(shrinking_factor); SetTrainerSpecFromFlag(num_threads); SetTrainerSpecFromFlag(num_sub_iterations); SetTrainerSpecFromFlag(max_sentencepiece_length); SetTrainerSpecFromFlag(max_sentence_length); SetTrainerSpecFromFlag(split_by_unicode_script); SetTrainerSpecFromFlag(split_by_whitespace); SetTrainerSpecFromFlag(split_by_number); SetTrainerSpecFromFlag(split_digits); SetTrainerSpecFromFlag(byte_fallback); SetTrainerSpecFromFlag(treat_whitespace_as_suffix); SetTrainerSpecFromFlag(hard_vocab_limit); SetTrainerSpecFromFlag(use_all_vocab); SetTrainerSpecFromFlag(unk_id); SetTrainerSpecFromFlag(bos_id); SetTrainerSpecFromFlag(eos_id); SetTrainerSpecFromFlag(pad_id); SetTrainerSpecFromFlag(unk_piece); SetTrainerSpecFromFlag(bos_piece); SetTrainerSpecFromFlag(eos_piece); SetTrainerSpecFromFlag(pad_piece); SetTrainerSpecFromFlag(unk_surface); SetTrainerSpecFromFlag(required_chars); SetTrainerSpecFromFile(required_chars); SetTrainerSpecFromFlag(vocabulary_output_piece_score); SetRepeatedTrainerSpecFromFlag(accept_language); SetRepeatedTrainerSpecFromFlag(control_symbols); SetRepeatedTrainerSpecFromFlag(user_defined_symbols); SetTrainerSpecFromFlag(train_extremely_large_corpus); SetRepeatedTrainerSpecFromFile(control_symbols); SetRepeatedTrainerSpecFromFile(user_defined_symbols); normalizer_spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name)); SetNormalizerSpecFromFlag(normalization_rule_tsv); SetNormalizerSpecFromFlag(add_dummy_prefix); SetNormalizerSpecFromFlag(remove_extra_whitespaces); normalizer_spec.set_encode_case(absl::GetFlag(FLAGS_encode_unicode_case)); if (!absl::GetFlag(FLAGS_denormalization_rule_tsv).empty() || absl::GetFlag(FLAGS_encode_unicode_case)) { if(!absl::GetFlag(FLAGS_denormalization_rule_tsv).empty()) denormalizer_spec.set_normalization_rule_tsv(absl::GetFlag(FLAGS_denormalization_rule_tsv)); denormalizer_spec.set_add_dummy_prefix(false); denormalizer_spec.set_remove_extra_whitespaces(false); denormalizer_spec.set_escape_whitespaces(false); denormalizer_spec.set_decode_case(absl::GetFlag(FLAGS_encode_unicode_case)); } CHECK_OK(sentencepiece::SentencePieceTrainer::PopulateModelTypeFromString( absl::GetFlag(FLAGS_model_type), &trainer_spec)); CHECK_OK(sentencepiece::SentencePieceTrainer::Train( trainer_spec, normalizer_spec, denormalizer_spec)); return 0; }