Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2018-06-06 18:30:16 +0300
committer: Taku Kudo <taku@google.com> 2018-06-06 18:30:16 +0300
commit: 53ea373084aa082d54d47f98c4697a88e7dfec30 (patch)
tree: d4b0eae4aa6a7838d49316462b68098d3fc99cc0 /src/spm_encode_main.cc
parent: fc7130f39450f3f18fcc91c8d0122810eb85598d (diff)
Added --generate_vocabulary option to spm_encode
Diffstat (limited to 'src/spm_encode_main.cc')
-rw-r--r-- src/spm_encode_main.cc 27
1 file changed, 22 insertions, 5 deletions
diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc
index ebaa066..5560b93 100644
--- a/src/spm_encode_main.cc
+++ b/src/spm_encode_main.cc
@@ -13,10 +13,12 @@
// limitations under the License.!
#include <functional>
+#include <unordered_map>
#include "common.h"
#include "flags.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"
+#include "trainer_interface.h"
#include "util.h"
DEFINE_string(model, "", "model file name");
@@ -37,6 +39,9 @@ DEFINE_string(vocabulary, "",
DEFINE_int32(vocabulary_threshold, 0,
"Words with frequency < threshold will be treated as OOV");
+DEFINE_bool(generate_vocabulary, false,
+ "Generates vocabulary file instead of segmentation");
+
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
@@ -63,11 +68,20 @@ int main(int argc, char *argv[]) {
std::vector<int> ids;
std::vector<std::vector<std::string>> nbest_sps;
std::vector<std::vector<int>> nbest_ids;
+ std::unordered_map<std::string, int> vocab;
sentencepiece::SentencePieceText spt;
sentencepiece::NBestSentencePieceText nbest_spt;
std::function<void(const std::string &line)> process;
- if (FLAGS_output_format == "piece") {
+ if (FLAGS_generate_vocabulary) {
+ process = [&](const std::string &line) {
+ CHECK_OK(sp.Encode(line, &spt));
+ for (const auto &piece : spt.pieces()) {
+ if (!sp.IsUnknown(piece.id()) && !sp.IsControl(piece.id()))
+ vocab[piece.piece()]++;
+ }
+ };
+ } else if (FLAGS_output_format == "piece") {
process = [&](const std::string &line) {
CHECK_OK(sp.Encode(line, &sps));
output.WriteLine(sentencepiece::string_util::Join(sps, " "));
@@ -124,13 +138,16 @@ int main(int argc, char *argv[]) {
sentencepiece::io::InputBuffer input(filename);
CHECK_OK(input.status());
while (input.ReadLine(&line)) {
- if (line.empty()) {
- output.WriteLine("");
- continue;
- }
process(line);
}
}
+ if (FLAGS_generate_vocabulary) {
+ for (const auto &it : sentencepiece::Sorted(vocab)) {
+ output.WriteLine(it.first + "\t" +
+ sentencepiece::string_util::SimpleItoa(it.second));
+ }
+ }
+
return 0;
}