diff options
author | Taku Kudo <taku@google.com> | 2018-06-06 10:47:59 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-06 10:47:59 +0300 |
commit | e437e30bb478d5841e41feeb10346296448bff2b (patch) | |
tree | e568af539f7b3c3dca1a2c8ee0e6ee514c415954 /src/sentencepiece_processor_test.cc | |
parent | c6e84aebc903a84758afeafcbeea54c2bc3f641e (diff) |
Support vocab restriction feature
Diffstat (limited to 'src/sentencepiece_processor_test.cc')
-rw-r--r-- | src/sentencepiece_processor_test.cc | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 15a965c..3a6f689 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -807,5 +807,202 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_NOT_OK(sp.SetEncodeExtraOptions("foo")); EXPECT_NOT_OK(sp.SetDecodeExtraOptions("foo")); + + auto RunTest = [&model_proto](const SentencePieceProcessor &sp) { + EXPECT_EQ(model_proto.DebugString(), sp.model_proto().DebugString()); + + EXPECT_EQ(8, sp.GetPieceSize()); + EXPECT_EQ(0, sp.PieceToId("<unk>")); + EXPECT_EQ(1, sp.PieceToId("<s>")); + EXPECT_EQ(2, sp.PieceToId("</s>")); + EXPECT_EQ(3, sp.PieceToId("a")); + EXPECT_EQ(4, sp.PieceToId("b")); + EXPECT_EQ(5, sp.PieceToId("c")); + EXPECT_EQ(6, sp.PieceToId("ab")); + EXPECT_EQ(7, sp.PieceToId("\xE2\x96\x81")); + + EXPECT_EQ("<unk>", sp.IdToPiece(0)); + EXPECT_EQ("<s>", sp.IdToPiece(1)); + EXPECT_EQ("</s>", sp.IdToPiece(2)); + EXPECT_EQ("a", sp.IdToPiece(3)); + EXPECT_EQ("b", sp.IdToPiece(4)); + EXPECT_EQ("c", sp.IdToPiece(5)); + EXPECT_EQ("ab", sp.IdToPiece(6)); + EXPECT_EQ("\xE2\x96\x81", sp.IdToPiece(7)); + + EXPECT_TRUE(sp.IsUnknown(0)); + EXPECT_FALSE(sp.IsUnknown(1)); + EXPECT_FALSE(sp.IsUnknown(2)); + EXPECT_FALSE(sp.IsUnknown(3)); + EXPECT_FALSE(sp.IsUnknown(4)); + EXPECT_FALSE(sp.IsUnknown(5)); + EXPECT_FALSE(sp.IsUnknown(6)); + EXPECT_FALSE(sp.IsUnknown(7)); + + EXPECT_FALSE(sp.IsControl(0)); + EXPECT_TRUE(sp.IsControl(1)); + EXPECT_TRUE(sp.IsControl(2)); + EXPECT_FALSE(sp.IsControl(3)); + EXPECT_FALSE(sp.IsControl(4)); + EXPECT_FALSE(sp.IsControl(5)); + EXPECT_FALSE(sp.IsControl(6)); + EXPECT_FALSE(sp.IsControl(7)); + + { + std::vector<std::string> sps; + const std::vector<std::string> expected_str = {WS, "ab", "c"}; + EXPECT_OK(sp.Encode("abc", &sps)); + EXPECT_EQ(expected_str, sps); + + std::vector<int> ids; + const std::vector<int> expected_id = {7, 6, 5}; + EXPECT_OK(sp.Encode("abc", &ids)); + EXPECT_EQ(expected_id, ids); + } + + { + std::string output; + const std::vector<std::string> sps = {"ab", "c"}; + EXPECT_OK(sp.Decode(sps, &output)); + EXPECT_EQ("abc", output); + + const std::vector<int> ids = {3, 4, 5}; + EXPECT_OK(sp.Decode(ids, &output)); + EXPECT_EQ("abc", output); + } + }; + + // Copies ModelProto. + { + SentencePieceProcessor sp; + const ModelProto copied = model_proto; + EXPECT_OK(sp.Load(copied)); + RunTest(sp); + } + + // Moves ModelProto. + { + SentencePieceProcessor sp; + auto moved = port::MakeUnique<ModelProto>(); + const ModelProto *moved_ptr = moved.get(); + *moved = model_proto; + EXPECT_OK(sp.Load(std::move(moved))); + EXPECT_EQ(moved_ptr, &sp.model_proto()); + RunTest(sp); + } + + // Restrict Vocabulary. + { + SentencePieceProcessor sp; + EXPECT_OK(sp.Load(model_proto)); + EXPECT_OK(sp.SetVocabulary({"a", "b", "c"})); // remove "ab" + + const std::vector<std::string> expected_str = {WS, "a", "b", "c"}; + std::vector<std::string> sps; + EXPECT_OK(sp.Encode("abc", &sps)); + EXPECT_EQ(expected_str, sps); + + std::vector<int> ids; + const std::vector<int> expected_id = {7, 3, 4, 5}; + EXPECT_OK(sp.Encode("abc", &ids)); + EXPECT_EQ(expected_id, ids); + } +} + +TEST(SentencePieceProcessorTest, VocabularyTest) { + ModelProto model_proto; + auto *sp1 = model_proto.add_pieces(); + auto *sp2 = model_proto.add_pieces(); + auto *sp3 = model_proto.add_pieces(); + + test::ScopedTempFile sf("vocab.txt"); + auto GetInlineFilename = [&sf](const std::string content) { + { + io::OutputBuffer out(sf.filename()); + out.Write(content); + } + return sf.filename(); + }; + + sp1->set_type(ModelProto::SentencePiece::UNKNOWN); + sp1->set_piece("<unk>"); + sp2->set_type(ModelProto::SentencePiece::CONTROL); + sp2->set_piece("<s>"); + sp3->set_type(ModelProto::SentencePiece::CONTROL); + sp3->set_piece("</s>"); + + AddPiece(&model_proto, "aa", 0.0); + AddPiece(&model_proto, "bb", 0.0); + AddPiece(&model_proto, "cc", 0.0); + AddPiece(&model_proto, "dd", 0.0); + AddPiece(&model_proto, "e", 0.0); + + SentencePieceProcessor sp; + EXPECT_OK(sp.Load(model_proto)); + + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_FALSE(sp.IsUnused(4)); + EXPECT_FALSE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.SetVocabulary({"aa", "dd", "e"})); + + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); // single char "e" is always used. + + EXPECT_OK(sp.ResetVocabulary()); + + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_FALSE(sp.IsUnused(4)); + EXPECT_FALSE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.SetVocabulary({"bb"})); + EXPECT_TRUE(sp.IsUnused(3)); + EXPECT_FALSE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_TRUE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t2\n"), 2)); + EXPECT_TRUE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 2)); + EXPECT_TRUE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_TRUE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 1)); + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\t0\ndd\t0\n"), 0)); + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); + + // No frequency. + EXPECT_OK(sp.LoadVocabulary(GetInlineFilename("aa\ndd\n"), 1)); + EXPECT_FALSE(sp.IsUnused(3)); + EXPECT_TRUE(sp.IsUnused(4)); + EXPECT_TRUE(sp.IsUnused(5)); + EXPECT_FALSE(sp.IsUnused(6)); + EXPECT_FALSE(sp.IsUnused(7)); } } // namespace sentencepiece |