Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2018-06-16 10:22:16 +0300
committer Taku Kudo <taku@google.com> 2018-06-16 10:22:16 +0300
commit 511b83196807d6653fccde7e34bcefd737f587ca (patch)
tree 06a2587b7324a192ce2fa0ef783402f87e7c2551 /src
parent 4093e91909841671cc872f61e7a3ae8237b1fbc8 (diff)
Minor fixes.
Diffstat (limited to 'src')
-rw-r--r-- src/sentencepiece_processor.cc 15
-rw-r--r-- src/sentencepiece_processor.h 5
-rw-r--r-- src/sentencepiece_processor_test.cc 23
3 files changed, 40 insertions, 3 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 73ec2a1..88c3266 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -551,8 +551,12 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
// static
util::Status SentencePieceProcessor::ParseExtraOptions(
const std::string &extra_option,
- std::vector<SentencePieceProcessor::ExtraOption> *extra_options) {
+ std::vector<SentencePieceProcessor::ExtraOption> *extra_options) const {
extra_options->clear();
+ if (extra_option.empty()) return util::OkStatus();
+
+ RETURN_IF_ERROR(status());
+
static std::map<std::string, SentencePieceProcessor::ExtraOption>
extra_option_map = {{"bos", SentencePieceProcessor::BOS},
{"eos", SentencePieceProcessor::EOS},
@@ -562,6 +566,15 @@ util::Status SentencePieceProcessor::ParseExtraOptions(
CHECK_OR_RETURN(it != extra_option_map.end())
<< "option \"" << s << "\" is not available.";
extra_options->push_back(it->second);
+
+ if (it->second == SentencePieceProcessor::BOS) {
+ CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>")))
+ << "id for `<s>` is not defined.";
+ }
+ if (it->second == SentencePieceProcessor::EOS) {
+ CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>")))
+ << "id for `</s>` is not defined.";
+ }
}
return util::OkStatus();
}
diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index 5945111..846130b 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h
@@ -312,8 +312,9 @@ class SentencePieceProcessor {
private:
enum ExtraOption { REVERSE, BOS, EOS };
- static util::Status ParseExtraOptions(
- const std::string &extra_option, std::vector<ExtraOption> *extra_options);
+ util::Status ParseExtraOptions(const std::string &extra_option,
+ std::vector<ExtraOption> *extra_options) const;
+
util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
SentencePieceText *spt) const;
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index ae65069..51a83db 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -805,6 +805,9 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
EXPECT_EQ("abc", output);
}
+ EXPECT_OK(sp.SetEncodeExtraOptions(""));
+ EXPECT_OK(sp.SetDecodeExtraOptions(""));
+
EXPECT_NOT_OK(sp.SetEncodeExtraOptions("foo"));
EXPECT_NOT_OK(sp.SetDecodeExtraOptions("foo"));
@@ -909,6 +912,26 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
}
}
+TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) {
+ ModelProto model_proto;
+ auto *sp1 = model_proto.add_pieces();

+ // No BOS/EOS: the model gets `<unk>` plus a, b, c, ab only, so `<s>` and `</s>` are never added.
+ sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
+ sp1->set_piece("<unk>");

+ AddPiece(&model_proto, "a", 0.0);
+ AddPiece(&model_proto, "b", 0.3);
+ AddPiece(&model_proto, "c", 0.2);
+ AddPiece(&model_proto, "ab", 1.0);

+ SentencePieceProcessor sp;
+ EXPECT_OK(sp.Load(model_proto));

+ EXPECT_NOT_OK(sp.SetEncodeExtraOptions("bos"));
+ EXPECT_NOT_OK(sp.SetDecodeExtraOptions("eos"));
+}
+
TEST(SentencePieceProcessorTest, VocabularyTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();