Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2018-06-06 10:47:59 +0300
committer Taku Kudo <taku@google.com> 2018-06-06 10:47:59 +0300
commit e437e30bb478d5841e41feeb10346296448bff2b (patch)
tree e568af539f7b3c3dca1a2c8ee0e6ee514c415954 /src/sentencepiece_processor_test.cc
parent c6e84aebc903a84758afeafcbeea54c2bc3f641e (diff)
Support vocab restriction feature
Diffstat (limited to 'src/sentencepiece_processor_test.cc')
-rw-r--r-- src/sentencepiece_processor_test.cc 197
1 file changed, 197 insertions, 0 deletions
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 15a965c..3a6f689 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -807,5 +807,202 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
EXPECT_NOT_OK(sp.SetEncodeExtraOptions("foo"));
EXPECT_NOT_OK(sp.SetDecodeExtraOptions("foo"));
+
+ auto RunTest = [&model_proto](const SentencePieceProcessor &sp) {  // Shared checks for a processor loaded from model_proto: model identity, id<->piece mapping, piece types, encode and decode.
+ EXPECT_EQ(model_proto.DebugString(), sp.model_proto().DebugString());  // The processor must expose the same model it was loaded with.
+
+ EXPECT_EQ(8, sp.GetPieceSize());  // Eight pieces in total, ids 0..7, checked one by one below.
+ EXPECT_EQ(0, sp.PieceToId("<unk>"));
+ EXPECT_EQ(1, sp.PieceToId("<s>"));
+ EXPECT_EQ(2, sp.PieceToId("</s>"));
+ EXPECT_EQ(3, sp.PieceToId("a"));
+ EXPECT_EQ(4, sp.PieceToId("b"));
+ EXPECT_EQ(5, sp.PieceToId("c"));
+ EXPECT_EQ(6, sp.PieceToId("ab"));
+ EXPECT_EQ(7, sp.PieceToId("\xE2\x96\x81"));  // "\xE2\x96\x81" is the UTF-8 encoding of U+2581.
+
+ EXPECT_EQ("<unk>", sp.IdToPiece(0));  // Reverse mapping must mirror PieceToId above.
+ EXPECT_EQ("<s>", sp.IdToPiece(1));
+ EXPECT_EQ("</s>", sp.IdToPiece(2));
+ EXPECT_EQ("a", sp.IdToPiece(3));
+ EXPECT_EQ("b", sp.IdToPiece(4));
+ EXPECT_EQ("c", sp.IdToPiece(5));
+ EXPECT_EQ("ab", sp.IdToPiece(6));
+ EXPECT_EQ("\xE2\x96\x81", sp.IdToPiece(7));
+
+ EXPECT_TRUE(sp.IsUnknown(0));  // Only id 0 ("<unk>") is expected to be the unknown piece.
+ EXPECT_FALSE(sp.IsUnknown(1));
+ EXPECT_FALSE(sp.IsUnknown(2));
+ EXPECT_FALSE(sp.IsUnknown(3));
+ EXPECT_FALSE(sp.IsUnknown(4));
+ EXPECT_FALSE(sp.IsUnknown(5));
+ EXPECT_FALSE(sp.IsUnknown(6));
+ EXPECT_FALSE(sp.IsUnknown(7));
+
+ EXPECT_FALSE(sp.IsControl(0));  // Only ids 1 ("<s>") and 2 ("</s>") are expected to be control pieces.
+ EXPECT_TRUE(sp.IsControl(1));
+ EXPECT_TRUE(sp.IsControl(2));
+ EXPECT_FALSE(sp.IsControl(3));
+ EXPECT_FALSE(sp.IsControl(4));
+ EXPECT_FALSE(sp.IsControl(5));
+ EXPECT_FALSE(sp.IsControl(6));
+ EXPECT_FALSE(sp.IsControl(7));
+
+ {  // Encode: the same input must yield matching piece strings and piece ids.
+ std::vector<std::string> sps;
+ const std::vector<std::string> expected_str = {WS, "ab", "c"};  // WS is defined outside this view; presumably the U+2581 piece (id 7 below) - TODO confirm.
+ EXPECT_OK(sp.Encode("abc", &sps));
+ EXPECT_EQ(expected_str, sps);
+
+ std::vector<int> ids;
+ const std::vector<int> expected_id = {7, 6, 5};
+ EXPECT_OK(sp.Encode("abc", &ids));
+ EXPECT_EQ(expected_id, ids);
+ }
+
+ {  // Decode: both a piece sequence and an id sequence must reconstruct "abc".
+ std::string output;
+ const std::vector<std::string> sps = {"ab", "c"};
+ EXPECT_OK(sp.Decode(sps, &output));
+ EXPECT_EQ("abc", output);
+
+ const std::vector<int> ids = {3, 4, 5};  // Ids of "a", "b", "c" per the mapping asserted above.
+ EXPECT_OK(sp.Decode(ids, &output));
+ EXPECT_EQ("abc", output);
+ }
+ };
+
+ // Copies ModelProto.
+ {
+ SentencePieceProcessor sp;
+ const ModelProto copied = model_proto;
+ EXPECT_OK(sp.Load(copied));
+ RunTest(sp);
+ }
+
+ // Moves ModelProto.
+ {
+ SentencePieceProcessor sp;
+ auto moved = port::MakeUnique<ModelProto>();
+ const ModelProto *moved_ptr = moved.get();
+ *moved = model_proto;
+ EXPECT_OK(sp.Load(std::move(moved)));
+ EXPECT_EQ(moved_ptr, &sp.model_proto());
+ RunTest(sp);
+ }
+
+ // Restrict Vocabulary.
+ {
+ SentencePieceProcessor sp;
+ EXPECT_OK(sp.Load(model_proto));
+ EXPECT_OK(sp.SetVocabulary({"a", "b", "c"})); // remove "ab"
+
+ const std::vector<std::string> expected_str = {WS, "a", "b", "c"};
+ std::vector<std::string> sps;
+ EXPECT_OK(sp.Encode("abc", &sps));
+ EXPECT_EQ(expected_str, sps);
+
+ std::vector<int> ids;
+ const std::vector<int> expected_id = {7, 3, 4, 5};
+ EXPECT_OK(sp.Encode("abc", &ids));
+ EXPECT_EQ(expected_id, ids);
+ }
+}
+
+TEST(SentencePieceProcessorTest, VocabularyTest) {  // Checks how SetVocabulary, ResetVocabulary and LoadVocabulary change what IsUnused reports for ids 3..7.
+ ModelProto model_proto;
+ auto *sp1 = model_proto.add_pieces();  // The first three pieces become <unk>, <s>, </s> below.
+ auto *sp2 = model_proto.add_pieces();
+ auto *sp3 = model_proto.add_pieces();
+
+ test::ScopedTempFile sf("vocab.txt");
+ auto GetInlineFilename = [&sf](const std::string content) {  // Writes `content` to the temp file and returns that file's name.
+ {  // Inner scope ends the OutputBuffer before returning; presumably this flushes the content - TODO confirm.
+ io::OutputBuffer out(sf.filename());
+ out.Write(content);
+ }
+ return sf.filename();
+ };
+
+ sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
+ sp1->set_piece("<unk>");
+ sp2->set_type(ModelProto::SentencePiece::CONTROL);
+ sp2->set_piece("<s>");
+ sp3->set_type(ModelProto::SentencePiece::CONTROL);
+ sp3->set_piece("</s>");
+
+ AddPiece(&model_proto, "aa", 0.0);  // AddPiece is defined outside this view; the assertions below treat aa..e as ids 3..7 - TODO confirm.
+ AddPiece(&model_proto, "bb", 0.0);
+ AddPiece(&model_proto, "cc", 0.0);
+ AddPiece(&model_proto, "dd", 0.0);
+ AddPiece(&model_proto, "e", 0.0);
+
+ SentencePieceProcessor sp;
+ EXPECT_OK(sp.Load(model_proto));
+
+ EXPECT_FALSE(sp.IsUnused(3));  // Before any restriction, no piece is expected to be unused.
+ EXPECT_FALSE(sp.IsUnused(4));
+ EXPECT_FALSE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.SetVocabulary({"aa", "dd", "e"}));  // Restrict to aa, dd, e; ids 4 and 5 are then expected to be unused.
+
+ EXPECT_FALSE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7)); // single char "e" is always used.
+
+ EXPECT_OK(sp.ResetVocabulary());  // After the reset, every piece is expected to be usable again.
+
+ EXPECT_FALSE(sp.IsUnused(3));
+ EXPECT_FALSE(sp.IsUnused(4));
+ EXPECT_FALSE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.SetVocabulary({"bb"}));  // Only "bb" is listed, yet id 7 is still expected to stay used.
+ EXPECT_TRUE(sp.IsUnused(3));
+ EXPECT_FALSE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_TRUE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t2\n"), 2));  // Second argument is presumably a minimum-frequency threshold: aa (1) is expected unused, dd (2) used - TODO confirm.
+ EXPECT_TRUE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 2));  // Both listed values are 1 with second argument 2: aa and dd are both expected unused.
+ EXPECT_TRUE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_TRUE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 1));  // Same file with second argument 1: aa and dd are both expected used.
+ EXPECT_FALSE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t0\ndd\t0\n"), 0));  // Listed values 0 with second argument 0: aa and dd are still expected used.
+ EXPECT_FALSE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
+
+ // No frequency.
+ EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\ndd\n"), 1));  // Entries without a second column are expected to count as used even with second argument 1.
+ EXPECT_FALSE(sp.IsUnused(3));
+ EXPECT_TRUE(sp.IsUnused(4));
+ EXPECT_TRUE(sp.IsUnused(5));
+ EXPECT_FALSE(sp.IsUnused(6));
+ EXPECT_FALSE(sp.IsUnused(7));
}
} // namespace sentencepiece