diff options
author | Taku Kudo <taku@google.com> | 2018-04-09 11:47:42 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-04-09 11:47:42 +0300 |
commit | d1028974960d9e7ac9b408f6c212aa90d7c958cb (patch) | |
tree | 4cda91a55a068786d91e6d78afb294b494fd9e3c /src/sentencepiece_trainer.cc | |
parent | 8ff70f28bd33368af3a9d7c74b672a1d9bb01095 (diff) |
Support to change ids of <unk>, <s>, </s>
Diffstat (limited to 'src/sentencepiece_trainer.cc')
-rw-r--r-- | src/sentencepiece_trainer.cc | 15 |
1 files changed, 15 insertions, 0 deletions
```diff
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 0f2873f..ad7d8ea 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -31,6 +31,8 @@ static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
 } // namespace
 
 DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
+              "Input format. Supported format is `text` or `tsv`.");
 DEFINE_string(model_prefix, "", "output model prefix");
 DEFINE_string(model_type, "unigram",
               "model algorithm: unigram, bpe, word or char");
@@ -76,6 +78,14 @@ DEFINE_bool(remove_extra_whitespaces,
             kDefaultNormalizerSpec.remove_extra_whitespaces(),
             "Removes leading, trailing, and "
             "duplicate internal whitespace");
+DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
+             "Override UNK (<unk>) id.");
+DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
+             "Override BOS (<s>) id. Set -1 to disable BOS.");
+DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
+             "Override EOS (</s>) id. Set -1 to disable EOS.");
+DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
+             "Override PAD (<pad>) id. Set -1 to disable PAD.");
 
 using sentencepiece::NormalizerSpec;
 using sentencepiece::TrainerSpec;
@@ -128,6 +138,7 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
   CHECK_NOTNULL(trainer_spec);
   CHECK_NOTNULL(normalizer_spec);
 
+  SetTrainerSpecFromFlag(input_format);
   SetTrainerSpecFromFlag(model_prefix);
   SetTrainerSpecFromFlag(vocab_size);
   SetTrainerSpecFromFlag(character_coverage);
@@ -141,6 +152,10 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
   SetTrainerSpecFromFlag(max_sentencepiece_length);
   SetTrainerSpecFromFlag(split_by_unicode_script);
   SetTrainerSpecFromFlag(split_by_whitespace);
+  SetTrainerSpecFromFlag(unk_id);
+  SetTrainerSpecFromFlag(bos_id);
+  SetTrainerSpecFromFlag(eos_id);
+  SetTrainerSpecFromFlag(pad_id);
   SetRepeatedTrainerSpecFromFlag(accept_language);
   SetRepeatedTrainerSpecFromFlag(control_symbols);
   SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
```