diff options
author | Taku Kudo <taku@google.com> | 2018-06-06 18:30:16 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-06 18:30:16 +0300 |
commit | 53ea373084aa082d54d47f98c4697a88e7dfec30 (patch) | |
tree | d4b0eae4aa6a7838d49316462b68098d3fc99cc0 /src/spm_encode_main.cc | |
parent | fc7130f39450f3f18fcc91c8d0122810eb85598d (diff) |
Added --generate_vocabulary option to spm_encode
Diffstat (limited to 'src/spm_encode_main.cc')
-rw-r--r-- | src/spm_encode_main.cc | 27 |
1 file changed, 22 insertions, 5 deletions
```diff
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc
index ebaa066..5560b93 100644
--- a/src/spm_encode_main.cc
+++ b/src/spm_encode_main.cc
@@ -13,10 +13,12 @@
 // limitations under the License.!
 
 #include <functional>
+#include <unordered_map>
 #include "common.h"
 #include "flags.h"
 #include "sentencepiece.pb.h"
 #include "sentencepiece_processor.h"
+#include "trainer_interface.h"
 #include "util.h"
 
 DEFINE_string(model, "", "model file name");
@@ -37,6 +39,9 @@ DEFINE_string(vocabulary, "",
 DEFINE_int32(vocabulary_threshold, 0,
              "Words with frequency < threshold will be treated as OOV");
 
+DEFINE_bool(generate_vocabulary, false,
+            "Generates vocabulary file instead of segmentation");
+
 int main(int argc, char *argv[]) {
   std::vector<std::string> rest_args;
   sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
@@ -63,11 +68,20 @@ int main(int argc, char *argv[]) {
   std::vector<int> ids;
   std::vector<std::vector<std::string>> nbest_sps;
   std::vector<std::vector<int>> nbest_ids;
+  std::unordered_map<std::string, int> vocab;
   sentencepiece::SentencePieceText spt;
   sentencepiece::NBestSentencePieceText nbest_spt;
   std::function<void(const std::string &line)> process;
 
-  if (FLAGS_output_format == "piece") {
+  if (FLAGS_generate_vocabulary) {
+    process = [&](const std::string &line) {
+      CHECK_OK(sp.Encode(line, &spt));
+      for (const auto &piece : spt.pieces()) {
+        if (!sp.IsUnknown(piece.id()) && !sp.IsControl(piece.id()))
+          vocab[piece.piece()]++;
+      }
+    };
+  } else if (FLAGS_output_format == "piece") {
     process = [&](const std::string &line) {
       CHECK_OK(sp.Encode(line, &sps));
       output.WriteLine(sentencepiece::string_util::Join(sps, " "));
@@ -124,13 +138,16 @@ int main(int argc, char *argv[]) {
     sentencepiece::io::InputBuffer input(filename);
     CHECK_OK(input.status());
     while (input.ReadLine(&line)) {
-      if (line.empty()) {
-        output.WriteLine("");
-        continue;
-      }
       process(line);
     }
   }
 
+  if (FLAGS_generate_vocabulary) {
+    for (const auto &it : sentencepiece::Sorted(vocab)) {
+      output.WriteLine(it.first + "\t" +
+                       sentencepiece::string_util::SimpleItoa(it.second));
+    }
+  }
+
   return 0;
 }
```