diff options
author | Taku Kudo <taku@google.com> | 2018-06-06 10:47:59 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-06 10:47:59 +0300 |
commit | e437e30bb478d5841e41feeb10346296448bff2b (patch) | |
tree | e568af539f7b3c3dca1a2c8ee0e6ee514c415954 /src/spm_encode_main.cc | |
parent | c6e84aebc903a84758afeafcbeea54c2bc3f641e (diff) |
Support vocab restriction feature
Diffstat (limited to 'src/spm_encode_main.cc')
-rw-r--r-- | src/spm_encode_main.cc | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index 5e50005..ebaa066 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -29,6 +29,14 @@ DEFINE_string(extra_options, "", DEFINE_int32(nbest_size, 10, "NBest size"); DEFINE_double(alpha, 0.5, "Smoothing parameter for sampling mode."); +// Piece restriction with vocabulary file. +// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt +DEFINE_string(vocabulary, "", + "Restrict the vocabulary. The encoder only emits the " + "tokens in \"vocabulary\" file"); +DEFINE_int32(vocabulary_threshold, 0, + "Words with frequency < threshold will be treated as OOV"); + int main(int argc, char *argv[]) { std::vector<std::string> rest_args; sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args); @@ -39,6 +47,10 @@ int main(int argc, char *argv[]) { CHECK_OK(sp.Load(FLAGS_model)); CHECK_OK(sp.SetEncodeExtraOptions(FLAGS_extra_options)); + if (!FLAGS_vocabulary.empty()) { + CHECK_OK(sp.LoadVocabulary(FLAGS_vocabulary, FLAGS_vocabulary_threshold)); + } + sentencepiece::io::OutputBuffer output(FLAGS_output); CHECK_OK(output.status()); |