Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2018-04-30 18:00:35 +0300
committerTaku Kudo <taku@google.com>2018-04-30 18:00:35 +0300
commitf228e556b484015cdb29c467416dd95ca38360cd (patch)
tree5b2b31e2a4c81d9f701610ecdfcaa4309d52b5d1 /src/spm_train_main.cc
parent36a3b35e17f995c5fb77c0397122b0d6d41fef1b (diff)
Reimplement Trainer with Proto reflection
Diffstat (limited to 'src/spm_train_main.cc')
-rw-r--r--src/spm_train_main.cc132
1 files changed, 131 insertions, 1 deletions
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 6f9f917..761cc15 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -12,10 +12,140 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include "builder.h"
+#include "flags.h"
#include "sentencepiece_trainer.h"
+#include "util.h"
+
+using sentencepiece::NormalizerSpec;
+using sentencepiece::TrainerSpec;
+using sentencepiece::normalizer::Builder;
+
+namespace {
+static sentencepiece::TrainerSpec kDefaultTrainerSpec;
+static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
+} // namespace
+
+DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
+ "Input format. Supported format is `text` or `tsv`.");
+DEFINE_string(model_prefix, "", "output model prefix");
+DEFINE_string(model_type, "unigram",
+ "model algorithm: unigram, bpe, word or char");
+DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
+DEFINE_string(accept_language, "",
+ "comma-separated list of languages this model can accept");
+DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
+ "character coverage to determine the minimum symbols");
+DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
+ "maximum size of sentences the trainer loads");
+DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
+ "maximum size of sentences to make seed sentence piece");
+DEFINE_int32(training_sentence_size,
+ kDefaultTrainerSpec.training_sentence_size(),
+ "maximum size of sentences to train sentence pieces");
+DEFINE_int32(seed_sentencepiece_size,
+ kDefaultTrainerSpec.seed_sentencepiece_size(),
+ "the size of seed sentencepieces");
+DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
+ "Keeps top shrinking_factor pieces with respect to the loss");
+DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
+ "number of threads for training");
+DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
+ "number of EM sub-iterations");
+DEFINE_int32(max_sentencepiece_length,
+ kDefaultTrainerSpec.max_sentencepiece_length(),
+ "maximum length of sentence piece");
+DEFINE_bool(split_by_unicode_script,
+ kDefaultTrainerSpec.split_by_unicode_script(),
+ "use Unicode script to split sentence pieces");
+DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
+ "use a white space to split sentence pieces");
+DEFINE_string(control_symbols, "", "comma separated list of control symbols");
+DEFINE_string(user_defined_symbols, "",
+ "comma separated list of user defined symbols");
+DEFINE_string(normalization_rule_name, "nfkc",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
+DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
+DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
+ "Add dummy whitespace at the beginning of text");
+DEFINE_bool(remove_extra_whitespaces,
+ kDefaultNormalizerSpec.remove_extra_whitespaces(),
+ "Removes leading, trailing, and "
+ "duplicate internal whitespace");
+DEFINE_bool(hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
+ "If set to false, --vocab_size is considered as a soft limit.");
+DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(), "Override UNK (<unk>) id.");
+DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
+ "Override BOS (<s>) id. Set -1 to disable BOS.");
+DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
+ "Override EOS (</s>) id. Set -1 to disable EOS.");
+DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
+ "Override PAD (<pad>) id. Set -1 to disable PAD.");
int main(int argc, char *argv[]) {
- sentencepiece::SentencePieceTrainer::Train(argc, argv);
+ sentencepiece::flags::ParseCommandLineFlags(argc, argv);
+ sentencepiece::TrainerSpec trainer_spec;
+ sentencepiece::NormalizerSpec normalizer_spec;
+
+ CHECK_OR_HELP(input);
+ CHECK_OR_HELP(model_prefix);
+
+// Populates the value from flags to spec.
+#define SetTrainerSpecFromFlag(name) trainer_spec.set_##name(FLAGS_##name);
+
+#define SetNormalizerSpecFromFlag(name) \
+ normalizer_spec.set_##name(FLAGS_##name);
+
+#define SetRepeatedTrainerSpecFromFlag(name) \
+ if (!FLAGS_##name.empty()) { \
+ for (const auto v : \
+ sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
+ trainer_spec.add_##name(v); \
+ } \
+ }
+
+ SetTrainerSpecFromFlag(input_format);
+ SetTrainerSpecFromFlag(model_prefix);
+ SetTrainerSpecFromFlag(vocab_size);
+ SetTrainerSpecFromFlag(character_coverage);
+ SetTrainerSpecFromFlag(input_sentence_size);
+ SetTrainerSpecFromFlag(mining_sentence_size);
+ SetTrainerSpecFromFlag(training_sentence_size);
+ SetTrainerSpecFromFlag(seed_sentencepiece_size);
+ SetTrainerSpecFromFlag(shrinking_factor);
+ SetTrainerSpecFromFlag(num_threads);
+ SetTrainerSpecFromFlag(num_sub_iterations);
+ SetTrainerSpecFromFlag(max_sentencepiece_length);
+ SetTrainerSpecFromFlag(split_by_unicode_script);
+ SetTrainerSpecFromFlag(split_by_whitespace);
+ SetTrainerSpecFromFlag(hard_vocab_limit);
+ SetTrainerSpecFromFlag(unk_id);
+ SetTrainerSpecFromFlag(bos_id);
+ SetTrainerSpecFromFlag(eos_id);
+ SetTrainerSpecFromFlag(pad_id);
+ SetRepeatedTrainerSpecFromFlag(input);
+ SetRepeatedTrainerSpecFromFlag(accept_language);
+ SetRepeatedTrainerSpecFromFlag(control_symbols);
+ SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
+
+ normalizer_spec.set_name(FLAGS_normalization_rule_name);
+ SetNormalizerSpecFromFlag(normalization_rule_tsv);
+ SetNormalizerSpecFromFlag(add_dummy_prefix);
+ SetNormalizerSpecFromFlag(remove_extra_whitespaces);
+
+ const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
+ {"unigram", TrainerSpec::UNIGRAM},
+ {"bpe", TrainerSpec::BPE},
+ {"word", TrainerSpec::WORD},
+ {"char", TrainerSpec::CHAR}};
+
+ trainer_spec.set_model_type(
+ sentencepiece::port::FindOrDie(kModelTypeMap, FLAGS_model_type));
+
+ CHECK_OK(sentencepiece::SentencePieceTrainer::Train(trainer_spec,
+ normalizer_spec));
return 0;
}