Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com>, 2018-02-28 07:14:52 +0300
committer: Taku Kudo <taku@google.com>, 2018-02-28 07:14:52 +0300
commit: c6a1a196651789ba4c0334dbf41d5885b3334b2f (patch)
tree: a6e2b3a0e7a71b9d4d21e9f8800dc9b4b3cea94b /src/sentencepiece_processor_test.cc
parent: ab766cbdaac1332776ae2c457fed9380f500159b (diff)
Add Sample/NBestEncode
Diffstat (limited to 'src/sentencepiece_processor_test.cc')
-rw-r--r-- src/sentencepiece_processor_test.cc 136
1 file changed, 115 insertions, 21 deletions
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 80ebead..dce8160 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -31,18 +31,32 @@ using port::MakeUnique;
class MockModel : public ModelInterface {
public:
- void SetEncodeResult(StringPiece input,
- const std::vector<std::pair<StringPiece, int>> &output) {
+ void SetEncodeResult(StringPiece input, const EncodeResult &output) {
input_ = input;
output_ = output;
}
- std::vector<std::pair<StringPiece, int>> Encode(
- StringPiece normalized) const {
+ void SetNBestEncodeResult(StringPiece input,
+ const NBestEncodeResult &output) {
+ input_ = input;
+ nbest_output_ = output;
+ }
+
+ EncodeResult Encode(StringPiece normalized) const {
+ EXPECT_EQ(normalized, input_);
+ return output_;
+ }
+
+ EncodeResult SampleEncode(StringPiece normalized, float alpha) const {
EXPECT_EQ(normalized, input_);
return output_;
}
+ NBestEncodeResult NBestEncode(StringPiece normalized, int nbest_size) const {
+ EXPECT_EQ(normalized, input_);
+ return nbest_output_;
+ }
+
bool IsControl(int id) const { return id == 1 || id == 2; }
bool IsUnknown(int id) const { return id == 0; }
@@ -57,11 +71,11 @@ class MockModel : public ModelInterface {
private:
StringPiece input_;
- std::vector<std::pair<StringPiece, int>> output_;
+ EncodeResult output_;
+ NBestEncodeResult nbest_output_;
};
-std::vector<std::string> GetSpVec(
- const std::vector<std::pair<StringPiece, int>> &pieces) {
+std::vector<std::string> GetSpVec(const EncodeResult &pieces) {
std::vector<std::string> sps;
for (const auto &p : pieces) {
sps.emplace_back(p.first.to_string());
@@ -90,7 +104,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
+ const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
@@ -132,9 +146,9 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
+ const EncodeResult result = {
{WS "ABC", 3}, {WS "D", 4}, {"E", 0}, {"F", 0}, {"</s>", 2}};
- const std::vector<std::pair<StringPiece, int>> expected = {
+ const EncodeResult expected = {
{WS "ABC", 3}, {WS "D", 4}, {"EF", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
@@ -176,7 +190,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
// ModelInterface::Encode() returns shorter results.
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {{WS "ABC", 3}};
+ const EncodeResult result = {{WS "ABC", 3}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
@@ -189,7 +203,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
// ModelInterface::Encode() returns longer results.
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
+ const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
@@ -203,7 +217,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
// ModelInterface::Encode() returns an empty piece.
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
+ const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"", 5}, {"F", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
@@ -216,8 +230,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
// Halfwidth to Fullwidth katakana normalization.
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
- {WS "グー", 3}, {"グル", 4}, {"</s>", 2}};
+ const EncodeResult result = {{WS "グー", 3}, {"グル", 4}, {"</s>", 2}};
const StringPiece input = WS "グーグル";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
@@ -251,8 +264,7 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
// One to many normalization.
{
auto mock = MakeUnique<MockModel>();
- const std::vector<std::pair<StringPiece, int>> result = {
- {WS "株式", 3}, {"会社", 4}, {"</s>", 2}};
+ const EncodeResult result = {{WS "株式", 3}, {"会社", 4}, {"</s>", 2}};
const StringPiece input = WS "株式会社";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
@@ -284,13 +296,95 @@ TEST(SentencepieceProcessorTest, EncodeTest) {
}
}
+TEST(SentencepieceProcessorTest, NBestEncodeTest) {  // Checks both NBestEncode() output forms against an n-best list served by MockModel.
+ const std::string kInput = WS "ABC" WS "DEF";  // String the mock expects to be handed; MockModel::NBestEncode() EXPECT_EQs its argument against it.
+ SentencePieceProcessor sp;
+
+ const auto normalization_spec = MakeDefaultNormalizerSpec();
+
+ auto mock = MakeUnique<MockModel>();
+
+ const NBestEncodeResult result = {  // Two candidates, each a (pieces, score) pair.
+ {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0},
+ {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.9}};
+
+ mock->SetNBestEncodeResult(kInput, result);
+ sp.SetModel(std::move(mock));  // sp takes over the mock; `mock` must not be used after this line.
+ sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
+
+ std::vector<std::vector<std::string>> output;  // First form: one vector of piece strings per candidate.
+ sp.NBestEncode("ABC DEF", 2, &output);  // Presumably normalized to kInput before reaching the mock -- confirm against Normalizer.
+ EXPECT_EQ(2, output.size());  // NOTE(review): EXPECT_EQ does not stop the test, so output[0]/output[1] below are read even on a size mismatch; consider ASSERT_EQ.
+ EXPECT_EQ(GetSpVec(result[0].first), output[0]);
+ EXPECT_EQ(GetSpVec(result[1].first), output[1]);
+
+ NBestSentencePieceText spt;  // Second form: structured result with per-candidate pieces and score.
+ sp.NBestEncode("ABC DEF", 2, &spt);
+ EXPECT_EQ(2, spt.nbests_size());  // NOTE(review): same non-fatal check followed by indexed access to nbests(0)/nbests(1).
+ EXPECT_EQ(4, spt.nbests(0).pieces_size());
+ EXPECT_EQ(4, spt.nbests(1).pieces_size());
+ EXPECT_NEAR(result[0].second, spt.nbests(0).score(), 0.001);  // Scores must match the mock's values within 0.001.
+ EXPECT_NEAR(result[1].second, spt.nbests(1).score(), 0.001);
+ for (int i = 0; i < 4; ++i) {  // Compare the piece text of all four pieces in both candidates.
+ EXPECT_EQ(result[0].first[i].first, spt.nbests(0).pieces(i).piece());
+ EXPECT_EQ(result[1].first[i].first, spt.nbests(1).pieces(i).piece());
+ }
+}
+
+TEST(SentencepieceProcessorTest, SampleEncodeTest) {  // Checks SampleEncode() output forms and the sampling frequency over a mocked n-best list.
+ const std::string kInput = WS "ABC" WS "DEF";  // String the mock expects to be handed; every MockModel encode method EXPECT_EQs its argument against it.
+ SentencePieceProcessor sp;
+
+ const auto normalization_spec = MakeDefaultNormalizerSpec();
+
+ auto mock = MakeUnique<MockModel>();
+
+ const EncodeResult result = {  // Returned by the mock's Encode() and SampleEncode().
+ {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
+ const NBestEncodeResult nbest_result = {  // Returned by the mock's NBestEncode(); (pieces, score) pairs with scores 1.0 and 0.1.
+ {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}}, 1.0},
+ {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}}, 0.1}};
+
+ mock->SetNBestEncodeResult(kInput, nbest_result);
+ mock->SetEncodeResult(kInput, result);
+ sp.SetModel(std::move(mock));  // sp takes over the mock; `mock` must not be used after this line.
+ sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
+
+ std::vector<std::string> output;
+ sp.SampleEncode("ABC DEF", -1, 0.5, &output);  // Presumably (input, nbest_size, alpha, out) with a negative size routed to the model's SampleEncode() -- confirm in SentencePieceProcessor.
+ EXPECT_EQ(4, output.size());
+ EXPECT_EQ(GetSpVec(result), output);
+
+ SentencePieceText spt;  // Same call through the structured-output overload.
+ sp.SampleEncode("ABC DEF", -1, 0.5, &spt);
+ EXPECT_EQ(4, spt.pieces_size());  // NOTE(review): EXPECT_EQ does not stop the test, so the loop below indexes pieces(0..3) even on a size mismatch; consider ASSERT_EQ.
+ for (int i = 0; i < 4; ++i) {  // Both the piece text and the id must match the mock's result.
+ EXPECT_EQ(result[i].first, spt.pieces(i).piece());
+ EXPECT_EQ(result[i].second, spt.pieces(i).id());
+ }
+
+ std::vector<int> freq(2, 0);  // freq[k] counts how often candidate k of nbest_result was drawn.
+ for (int i = 0; i < 5000; ++i) {  // 5000 draws, this time with 20 as the second argument.
+ sp.SampleEncode("ABC DEF", 20, 0.5, &output);
+ EXPECT_EQ(4, output.size());
+ if (GetSpVec(nbest_result[0].first) == output)
+ freq[0]++;
+ else if (GetSpVec(nbest_result[1].first) == output)
+ freq[1]++;
+ else
+ LOG(FATAL) << "Invalid result.";  // An output matching neither candidate is treated as fatal.
+ }
+
+ const float expected_prob =  // Two-way softmax of 0.5 * score over the candidate scores 1.0 and 0.1; 0.5 is the third argument passed above.
+ std::exp(0.5 * 1.0) / (std::exp(0.5 * 1.0) + std::exp(0.5 * 0.1));
+ const float prob = 1.0 * freq[0] / (freq[0] + freq[1]);  // Observed share of draws that returned candidate 0.
+ EXPECT_NEAR(prob, expected_prob, 0.05);  // Statistical check: observed share must be within 0.05 of the expected probability.
+}
+
TEST(SentencepieceProcessorTest, DecodeTest) {
class DecodeMockModel : public ModelInterface {
public:
- std::vector<std::pair<StringPiece, int>> Encode(
- StringPiece normalized) const override {
- return {};
- }
+ EncodeResult Encode(StringPiece normalized) const override { return {}; }
int GetPieceSize() const override { return 7; }