Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com>, 2018-04-09 11:47:42 +0300
committer: Taku Kudo <taku@google.com>, 2018-04-09 11:47:42 +0300
commit: d1028974960d9e7ac9b408f6c212aa90d7c958cb (patch)
tree: 4cda91a55a068786d91e6d78afb294b494fd9e3c /src/sentencepiece_trainer.cc
parent: 8ff70f28bd33368af3a9d7c74b672a1d9bb01095 (diff)
Support to change ids of <unk>, <s>, </s>
Diffstat (limited to 'src/sentencepiece_trainer.cc')
-rw-r--r-- src/sentencepiece_trainer.cc | 15
1 file changed, 15 insertions, 0 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 0f2873f..ad7d8ea 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -31,6 +31,8 @@ static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
} // namespace
DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
+ "Input format. Supported format is `text` or `tsv`.");
DEFINE_string(model_prefix, "", "output model prefix");
DEFINE_string(model_type, "unigram",
"model algorithm: unigram, bpe, word or char");
@@ -76,6 +78,14 @@ DEFINE_bool(remove_extra_whitespaces,
kDefaultNormalizerSpec.remove_extra_whitespaces(),
"Removes leading, trailing, and "
"duplicate internal whitespace");
+DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
+ "Override UNK (<unk>) id.");
+DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
+ "Override BOS (<s>) id. Set -1 to disable BOS.");
+DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
+ "Override EOS (</s>) id. Set -1 to disable EOS.");
+DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
+ "Override PAD (<pad>) id. Set -1 to disable PAD.");
using sentencepiece::NormalizerSpec;
using sentencepiece::TrainerSpec;
@@ -128,6 +138,7 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
CHECK_NOTNULL(trainer_spec);
CHECK_NOTNULL(normalizer_spec);
+ SetTrainerSpecFromFlag(input_format);
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
SetTrainerSpecFromFlag(character_coverage);
@@ -141,6 +152,10 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
SetTrainerSpecFromFlag(max_sentencepiece_length);
SetTrainerSpecFromFlag(split_by_unicode_script);
SetTrainerSpecFromFlag(split_by_whitespace);
+ SetTrainerSpecFromFlag(unk_id);
+ SetTrainerSpecFromFlag(bos_id);
+ SetTrainerSpecFromFlag(eos_id);
+ SetTrainerSpecFromFlag(pad_id);
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);