From 2928ce5307224ea4c012fc6cbd7a098c486590b6 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Tue, 7 Mar 2017 19:43:50 +0900 Subject: Initialize repository --- src/spm_decode_main.cc | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 src/spm_decode_main.cc (limited to 'src/spm_decode_main.cc') diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc new file mode 100644 index 0000000..7166e1f --- /dev/null +++ b/src/spm_decode_main.cc @@ -0,0 +1,97 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +#include "common.h" +#include "flags.h" +#include "sentencepiece.pb.h" +#include "sentencepiece_processor.h" +#include "util.h" + +DEFINE_string(model, "", "model file name"); +DEFINE_string(output, "", "output filename"); +DEFINE_string(input_format, "piece", "choose from piece or id"); +DEFINE_string(output_format, "string", "choose from string or proto"); +DEFINE_string(extra_options, "", + "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); + +int main(int argc, char *argv[]) { + std::vector rest_args; + sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args); + + CHECK_OR_HELP(model); + + sentencepiece::SentencePieceProcessor sp; + sp.LoadOrDie(FLAGS_model); + sp.SetDecodeExtraOptions(FLAGS_extra_options); + + sentencepiece::io::OutputBuffer output(FLAGS_output); + + if (rest_args.empty()) { + rest_args.push_back(""); // empty means that reading from stdin. + } + + std::string detok, line; + sentencepiece::SentencePieceText spt; + std::function &pieces)> process; + + auto ToIds = [&](const std::vector &pieces) { + std::vector ids; + for (const auto &s : pieces) { + ids.push_back(atoi(s.c_str())); + } + return ids; + }; + + if (FLAGS_input_format == "piece") { + if (FLAGS_output_format == "string") { + process = [&](const std::vector &pieces) { + sp.Decode(pieces, &detok); + output.WriteLine(detok); + }; + } else if (FLAGS_output_format == "proto") { + process = [&](const std::vector &pieces) { + sp.Decode(pieces, &spt); + output.WriteLine(spt.Utf8DebugString()); + }; + } else { + LOG(FATAL) << "Unknown output format: " << FLAGS_output_format; + } + } else if (FLAGS_input_format == "id") { + if (FLAGS_output_format == "string") { + process = [&](const std::vector &pieces) { + sp.Decode(ToIds(pieces), &detok); + output.WriteLine(detok); + }; + } else if (FLAGS_output_format == "proto") { + process = [&](const std::vector &pieces) { + sp.Decode(ToIds(pieces), &spt); + output.WriteLine(spt.Utf8DebugString()); + }; + } else { + LOG(FATAL) << "Unknown output format: " << FLAGS_output_format; + } + } else { + LOG(FATAL) << "Unknown input format: " << FLAGS_input_format; + } + + for (const auto &filename : rest_args) { + sentencepiece::io::InputBuffer input(filename); + while (input.ReadLine(&line)) { + const auto pieces = sentencepiece::string_util::Split(line, " "); + process(pieces); + } + } + + return 0; +} -- cgit v1.2.3