Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2018-02-28 13:54:11 +0300
committerTaku Kudo <taku@google.com>2018-02-28 13:54:11 +0300
commit45b4527117c5bf52b9bb14c33de9ec7facae9c93 (patch)
tree3f185406804145defb1a8622c6347f33cd5931de /src/sentencepiece_trainer.cc
parentc6a1a196651789ba4c0334dbf41d5885b3334b2f (diff)
Added SentencePieceTrainer class
Diffstat (limited to 'src/sentencepiece_trainer.cc')
-rw-r--r--src/sentencepiece_trainer.cc202
1 files changed, 202 insertions, 0 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
new file mode 100644
index 0000000..0f2873f
--- /dev/null
+++ b/src/sentencepiece_trainer.cc
@@ -0,0 +1,202 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.!
+
+#include "sentencepiece_trainer.h"
+#include <mutex>
+#include <string>
+
+#include "builder.h"
+#include "common.h"
+#include "flags.h"
+#include "normalizer.h"
+#include "sentencepiece.pb.h"
+#include "sentencepiece_model.pb.h"
+#include "trainer_factory.h"
+#include "util.h"
+
+namespace {
+static const sentencepiece::TrainerSpec kDefaultTrainerSpec;
+static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
+} // namespace
+
+DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(model_prefix, "", "output model prefix");
+DEFINE_string(model_type, "unigram",
+ "model algorithm: unigram, bpe, word or char");
+DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
+DEFINE_string(accept_language, "",
+ "comma-separated list of languages this model can accept");
+DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
+ "character coverage to determine the minimum symbols");
+DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
+ "maximum size of sentences the trainer loads");
+DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
+ "maximum size of sentences to make seed sentence piece");
+DEFINE_int32(training_sentence_size,
+ kDefaultTrainerSpec.training_sentence_size(),
+ "maximum size of sentences to train sentence pieces");
+DEFINE_int32(seed_sentencepiece_size,
+ kDefaultTrainerSpec.seed_sentencepiece_size(),
+ "the size of seed sentencepieces");
+DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
+ "Keeps top shrinking_factor pieces with respect to the loss");
+DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
+ "number of threads for training");
+DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
+ "number of EM sub-iterations");
+DEFINE_int32(max_sentencepiece_length,
+ kDefaultTrainerSpec.max_sentencepiece_length(),
+ "maximum length of sentence piece");
+DEFINE_bool(split_by_unicode_script,
+ kDefaultTrainerSpec.split_by_unicode_script(),
+ "use Unicode script to split sentence pieces");
+DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
+ "use a white space to split sentence pieces");
+DEFINE_string(control_symbols, "", "comma separated list of control symbols");
+DEFINE_string(user_defined_symbols, "",
+ "comma separated list of user defined symbols");
+DEFINE_string(normalization_rule_name, "nfkc",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
+DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
+DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
+ "Add dummy whitespace at the beginning of text");
+DEFINE_bool(remove_extra_whitespaces,
+ kDefaultNormalizerSpec.remove_extra_whitespaces(),
+ "Removes leading, trailing, and "
+ "duplicate internal whitespace");
+
+using sentencepiece::NormalizerSpec;
+using sentencepiece::TrainerSpec;
+using sentencepiece::normalizer::Builder;
+
+namespace sentencepiece {
+namespace {
+static constexpr char kDefaultNormalizerName[] = "nfkc";
+
+NormalizerSpec MakeNormalizerSpecFromFlags() {
+ if (!FLAGS_normalization_rule_tsv.empty()) {
+ const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile(
+ FLAGS_normalization_rule_tsv);
+ sentencepiece::NormalizerSpec spec;
+ spec.set_name("user_defined");
+ spec.set_precompiled_charsmap(
+ sentencepiece::normalizer::Builder::CompileCharsMap(chars_map));
+ return spec;
+ }
+
+ return sentencepiece::normalizer::Builder::GetNormalizerSpec(
+ FLAGS_normalization_rule_name);
+}
+
+TrainerSpec::ModelType GetModelTypeFromString(const std::string &type) {
+ const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
+ {"unigram", TrainerSpec::UNIGRAM},
+ {"bpe", TrainerSpec::BPE},
+ {"word", TrainerSpec::WORD},
+ {"char", TrainerSpec::CHAR}};
+ return port::FindOrDie(kModelTypeMap, type);
+}
+
+// Populates the value from flags to spec.
+#define SetTrainerSpecFromFlag(name) trainer_spec->set_##name(FLAGS_##name);
+
+#define SetNormalizerSpecFromFlag(name) \
+ normalizer_spec->set_##name(FLAGS_##name);
+
+#define SetRepeatedTrainerSpecFromFlag(name) \
+ if (!FLAGS_##name.empty()) { \
+ for (const auto v : \
+ sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
+ trainer_spec->add_##name(v); \
+ } \
+ }
+
+void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
+ NormalizerSpec *normalizer_spec) {
+ CHECK_NOTNULL(trainer_spec);
+ CHECK_NOTNULL(normalizer_spec);
+
+ SetTrainerSpecFromFlag(model_prefix);
+ SetTrainerSpecFromFlag(vocab_size);
+ SetTrainerSpecFromFlag(character_coverage);
+ SetTrainerSpecFromFlag(input_sentence_size);
+ SetTrainerSpecFromFlag(mining_sentence_size);
+ SetTrainerSpecFromFlag(training_sentence_size);
+ SetTrainerSpecFromFlag(seed_sentencepiece_size);
+ SetTrainerSpecFromFlag(shrinking_factor);
+ SetTrainerSpecFromFlag(num_threads);
+ SetTrainerSpecFromFlag(num_sub_iterations);
+ SetTrainerSpecFromFlag(max_sentencepiece_length);
+ SetTrainerSpecFromFlag(split_by_unicode_script);
+ SetTrainerSpecFromFlag(split_by_whitespace);
+ SetRepeatedTrainerSpecFromFlag(accept_language);
+ SetRepeatedTrainerSpecFromFlag(control_symbols);
+ SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
+
+ *normalizer_spec = MakeNormalizerSpecFromFlags();
+ SetNormalizerSpecFromFlag(add_dummy_prefix);
+ SetNormalizerSpecFromFlag(remove_extra_whitespaces);
+
+ for (const auto &filename :
+ sentencepiece::string_util::Split(FLAGS_input, ",")) {
+ trainer_spec->add_input(filename);
+ }
+
+ trainer_spec->set_model_type(GetModelTypeFromString(FLAGS_model_type));
+}
+} // namespace
+
+// static
+void SentencePieceTrainer::Train(int argc, char **argv) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+ {
+ static std::mutex flags_mutex;
+ std::lock_guard<std::mutex> lock(flags_mutex);
+ sentencepiece::flags::ParseCommandLineFlags(argc, argv);
+ CHECK_OR_HELP(input);
+ CHECK_OR_HELP(model_prefix);
+ MakeTrainerSpecFromFlags(&trainer_spec, &normalizer_spec);
+ }
+
+ SentencePieceTrainer::Train(trainer_spec, normalizer_spec);
+}
+
+// static
+void SentencePieceTrainer::Train(const std::string &arg) {
+ const std::vector<std::string> args =
+ sentencepiece::string_util::Split(arg, " ");
+ std::vector<char *> cargs(args.size() + 1);
+ cargs[0] = const_cast<char *>("");
+ for (size_t i = 0; i < args.size(); ++i)
+ cargs[i + 1] = const_cast<char *>(args[i].data());
+ SentencePieceTrainer::Train(static_cast<int>(cargs.size()), &cargs[0]);
+}
+
+// static
+void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
+ SentencePieceTrainer::Train(
+ trainer_spec,
+ normalizer::Builder::GetNormalizerSpec(kDefaultNormalizerName));
+}
+
+// static
+void SentencePieceTrainer::Train(const TrainerSpec &trainer_spec,
+ const NormalizerSpec &normalizer_spec) {
+ auto trainer = TrainerFactory::Create(trainer_spec, normalizer_spec);
+ trainer->Train();
+}
+
+} // namespace sentencepiece