diff options
author | Taku Kudo <taku@google.com> | 2018-02-28 07:14:52 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-02-28 07:14:52 +0300 |
commit | c6a1a196651789ba4c0334dbf41d5885b3334b2f (patch) | |
tree | a6e2b3a0e7a71b9d4d21e9f8800dc9b4b3cea94b /src/sentencepiece_processor_test.cc | |
parent | ab766cbdaac1332776ae2c457fed9380f500159b (diff) |
Add Sample/NBestEncode
Diffstat (limited to 'src/sentencepiece_processor_test.cc')
-rw-r--r-- | src/sentencepiece_processor_test.cc | 136 |
1 files changed, 115 insertions, 21 deletions
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 80ebead..dce8160 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -31,18 +31,32 @@ using port::MakeUnique; class MockModel : public ModelInterface { public: - void SetEncodeResult(StringPiece input, - const std::vector<std::pair<StringPiece, int>> &output) { + void SetEncodeResult(StringPiece input, const EncodeResult &output) { input_ = input; output_ = output; } - std::vector<std::pair<StringPiece, int>> Encode( - StringPiece normalized) const { + void SetNBestEncodeResult(StringPiece input, + const NBestEncodeResult &output) { + input_ = input; + nbest_output_ = output; + } + + EncodeResult Encode(StringPiece normalized) const { + EXPECT_EQ(normalized, input_); + return output_; + } + + EncodeResult SampleEncode(StringPiece normalized, float alpha) const { EXPECT_EQ(normalized, input_); return output_; } + NBestEncodeResult NBestEncode(StringPiece normalized, int nbest_size) const { + EXPECT_EQ(normalized, input_); + return nbest_output_; + } + bool IsControl(int id) const { return id == 1 || id == 2; } bool IsUnknown(int id) const { return id == 0; } @@ -57,11 +71,11 @@ class MockModel : public ModelInterface { private: StringPiece input_; - std::vector<std::pair<StringPiece, int>> output_; + EncodeResult output_; + NBestEncodeResult nbest_output_; }; -std::vector<std::string> GetSpVec( - const std::vector<std::pair<StringPiece, int>> &pieces) { +std::vector<std::string> GetSpVec(const EncodeResult &pieces) { std::vector<std::string> sps; for (const auto &p : pieces) { sps.emplace_back(p.first.to_string()); @@ -90,7 +104,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { + const EncodeResult result = { {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}; mock->SetEncodeResult(kInput, result); @@ -132,9 +146,9 @@ TEST(SentencepieceProcessorTest, EncodeTest) { { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { + const EncodeResult result = { {WS "ABC", 3}, {WS "D", 4}, {"E", 0}, {"F", 0}, {"</s>", 2}}; - const std::vector<std::pair<StringPiece, int>> expected = { + const EncodeResult expected = { {WS "ABC", 3}, {WS "D", 4}, {"EF", 0}, {"</s>", 2}}; mock->SetEncodeResult(kInput, result); @@ -176,7 +190,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { // ModelInterface::Encode() returns shorter results. { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = {{WS "ABC", 3}}; + const EncodeResult result = {{WS "ABC", 3}}; mock->SetEncodeResult(kInput, result); sp.SetModel(std::move(mock)); sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec)); @@ -189,7 +203,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { // ModelInterface::Encode() returns longer results. { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { + const EncodeResult result = { {WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G", 6}}; mock->SetEncodeResult(kInput, result); sp.SetModel(std::move(mock)); @@ -203,7 +217,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { // ModelInterface::Encode() returns an empty piece. { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { + const EncodeResult result = { {WS "ABC", 3}, {WS "DE", 4}, {"", 5}, {"F", 6}}; mock->SetEncodeResult(kInput, result); sp.SetModel(std::move(mock)); @@ -216,8 +230,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { // Halfwidth to Fullwidith katakana normalization. { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { - {WS "グー", 3}, {"グル", 4}, {"</s>", 2}}; + const EncodeResult result = {{WS "グー", 3}, {"グル", 4}, {"</s>", 2}}; const StringPiece input = WS "グーグル"; mock->SetEncodeResult(input, result); sp.SetModel(std::move(mock)); @@ -251,8 +264,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) { // One to many normalization. { auto mock = MakeUnique<MockModel>(); - const std::vector<std::pair<StringPiece, int>> result = { - {WS "株式", 3}, {"会社", 4}, {"</s>", 2}}; + const EncodeResult result = {{WS "株式", 3}, {"会社", 4}, {"</s>", 2}}; const StringPiece input = WS "株式会社"; mock->SetEncodeResult(input, result); sp.SetModel(std::move(mock)); @@ -284,13 +296,95 @@ TEST(SentencepieceProcessorTest, EncodeTest) { } } +TEST(SentencepieceProcessorTest, NBestEncodeTest) { + const std::string kInput = WS "ABC" WS "DEF"; + SentencePieceProcessor sp; + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + + auto mock = MakeUnique<MockModel>(); + + const NBestEncodeResult result = { + {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0}, + {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.9}}; + + mock->SetNBestEncodeResult(kInput, result); + sp.SetModel(std::move(mock)); + sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec)); + + std::vector<std::vector<std::string>> output; + sp.NBestEncode("ABC DEF", 2, &output); + EXPECT_EQ(2, output.size()); + EXPECT_EQ(GetSpVec(result[0].first), output[0]); + EXPECT_EQ(GetSpVec(result[1].first), output[1]); + + NBestSentencePieceText spt; + sp.NBestEncode("ABC DEF", 2, &spt); + EXPECT_EQ(2, spt.nbests_size()); + EXPECT_EQ(4, spt.nbests(0).pieces_size()); + EXPECT_EQ(4, spt.nbests(1).pieces_size()); + EXPECT_NEAR(result[0].second, spt.nbests(0).score(), 0.001); + EXPECT_NEAR(result[1].second, spt.nbests(1).score(), 0.001); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(result[0].first[i].first, spt.nbests(0).pieces(i).piece()); + EXPECT_EQ(result[1].first[i].first, spt.nbests(1).pieces(i).piece()); + } +} + +TEST(SentencepieceProcessorTest, SampleEncodeTest) { + const std::string kInput = WS "ABC" WS "DEF"; + SentencePieceProcessor sp; + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + + auto mock = MakeUnique<MockModel>(); + + const EncodeResult result = { + {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}; + const NBestEncodeResult nbest_result = { + {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0}, + {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.1}}; + + mock->SetNBestEncodeResult(kInput, nbest_result); + mock->SetEncodeResult(kInput, result); + sp.SetModel(std::move(mock)); + sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec)); + + std::vector<std::string> output; + sp.SampleEncode("ABC DEF", -1, 0.5, &output); + EXPECT_EQ(4, output.size()); + EXPECT_EQ(GetSpVec(result), output); + + SentencePieceText spt; + sp.SampleEncode("ABC DEF", -1, 0.5, &spt); + EXPECT_EQ(4, spt.pieces_size()); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(result[i].first, spt.pieces(i).piece()); + EXPECT_EQ(result[i].second, spt.pieces(i).id()); + } + + std::vector<int> freq(2, 0); + for (int i = 0; i < 5000; ++i) { + sp.SampleEncode("ABC DEF", 20, 0.5, &output); + EXPECT_EQ(4, output.size()); + if (GetSpVec(nbest_result[0].first) == output) + freq[0]++; + else if (GetSpVec(nbest_result[1].first) == output) + freq[1]++; + else + LOG(FATAL) << "Invalid result."; + } + + const float expected_prob = + std::exp(0.5 * 1.0) / (std::exp(0.5 * 1.0) + std::exp(0.5 * 0.1)); + const float prob = 1.0 * freq[0] / (freq[0] + freq[1]); + EXPECT_NEAR(prob, expected_prob, 0.05); +} + TEST(SentencepieceProcessorTest, DecodeTest) { class DecodeMockModel : public ModelInterface { public: - std::vector<std::pair<StringPiece, int>> Encode( - StringPiece normalized) const override { - return {}; - } + EncodeResult Encode(StringPiece normalized) const override { return {}; } int GetPieceSize() const override { return 7; } |