// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #include #include "filesystem.h" #include "testharness.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_format.h" #include "trainer_interface.h" #include "util.h" namespace sentencepiece { // Space symbol #define WS "\xe2\x96\x81" // Converts the 1 unicode string to the code point. static char32 ToChar32(absl::string_view str) { string_util::UnicodeText utext = string_util::UTF8ToUnicodeText(str); return !utext.empty() ? *utext.begin() : 0; } TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; trainer_spec.set_model_prefix("model"); trainer_spec.add_input("input"); // Calls the default method for better coverage. TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.Train().ok()); auto IsValid = [&trainer_spec, &normalizer_spec, &denormalizer_spec](const std::string &str) { TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); const string_util::UnicodeText text = string_util::UTF8ToUnicodeText(str); return trainer.IsValidSentencePiece(text); }; EXPECT_FALSE(trainer.IsValidSentencePiece({0x01, 0x00, 0x01})); EXPECT_FALSE(trainer.IsValidSentencePiece({0x01, 0x00})); EXPECT_FALSE(trainer.IsValidSentencePiece({0x00, 0x01})); EXPECT_FALSE(trainer.IsValidSentencePiece({0x00})); // Default trainer spec. EXPECT_FALSE(IsValid("")); EXPECT_FALSE(IsValid("12345678912345678")); // too long EXPECT_TRUE(IsValid("a")); EXPECT_TRUE(IsValid(WS)); EXPECT_TRUE(IsValid(WS "a")); EXPECT_FALSE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); EXPECT_FALSE(IsValid("a" WS "b")); EXPECT_FALSE(IsValid("a" WS "b" WS)); EXPECT_TRUE(IsValid("あいう")); EXPECT_TRUE(IsValid("グーグル")); // "ー" is a part of Katakana EXPECT_TRUE(IsValid("食べる")); EXPECT_FALSE(IsValid("漢字ABC")); // mixed CJK scripts EXPECT_FALSE(IsValid("F1")); EXPECT_FALSE(IsValid("1F")); EXPECT_FALSE(IsValid("1A2")); EXPECT_TRUE(IsValid("$10")); // $ and 1 are both "common" script. EXPECT_FALSE(IsValid("$ABC")); EXPECT_FALSE(IsValid("ab\tbc")); // "\t" is UPP boundary. EXPECT_FALSE(IsValid("ab cd")); EXPECT_FALSE(IsValid("\0\0")); EXPECT_FALSE(IsValid("\0")); trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); EXPECT_TRUE(IsValid(WS "a")); EXPECT_FALSE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); EXPECT_TRUE(IsValid("a" WS "b")); // "a b" is a valid piece. EXPECT_TRUE(IsValid(WS "a" WS "b")); EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c")); EXPECT_FALSE(IsValid("a" WS "b" WS)); trainer_spec.set_split_by_unicode_script(false); EXPECT_TRUE(IsValid("あいう")); EXPECT_TRUE(IsValid("グーグル")); EXPECT_TRUE(IsValid("食べる")); EXPECT_TRUE(IsValid("漢字ABC")); EXPECT_TRUE(IsValid("F1")); EXPECT_TRUE(IsValid("$10")); EXPECT_TRUE(IsValid("$ABC")); trainer_spec.set_max_sentencepiece_length(4); EXPECT_TRUE(IsValid("1234")); EXPECT_FALSE(IsValid("12345")); trainer_spec.set_max_sentencepiece_length(10); trainer_spec.set_split_by_unicode_script(true); trainer_spec.set_split_by_number(false); EXPECT_TRUE(IsValid("F1")); EXPECT_TRUE(IsValid("11")); EXPECT_TRUE(IsValid("1F")); EXPECT_TRUE(IsValid("ABC")); EXPECT_TRUE(IsValid("1A2")); EXPECT_TRUE(IsValid("1a234abc")); EXPECT_FALSE(IsValid("9Aあ")); EXPECT_TRUE(IsValid("9あい0A")); trainer_spec.set_split_by_whitespace(true); trainer_spec.set_treat_whitespace_as_suffix(true); EXPECT_TRUE(IsValid(WS)); EXPECT_FALSE(IsValid(WS "a")); EXPECT_TRUE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); EXPECT_FALSE(IsValid("a" WS "b")); EXPECT_FALSE(IsValid(WS "a" WS "b")); EXPECT_FALSE(IsValid("a" WS "b" WS)); trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); EXPECT_FALSE(IsValid(WS "a")); EXPECT_TRUE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); EXPECT_TRUE(IsValid("a" WS "b")); EXPECT_FALSE(IsValid(WS "a" WS "b")); EXPECT_TRUE(IsValid("a" WS "b" WS)); trainer_spec.set_split_digits(false); EXPECT_TRUE(IsValid("1")); EXPECT_TRUE(IsValid("59")); EXPECT_TRUE(IsValid("2007")); EXPECT_TRUE(IsValid("x1")); EXPECT_TRUE(IsValid("2x")); trainer_spec.set_split_digits(true); EXPECT_TRUE(IsValid("1")); EXPECT_FALSE(IsValid("59")); EXPECT_FALSE(IsValid("2007")); EXPECT_FALSE(IsValid("x1")); EXPECT_FALSE(IsValid("2x")); } TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) { TrainerSpec base_trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; base_trainer_spec.set_model_prefix("model"); base_trainer_spec.add_input("input"); auto trainer_spec = base_trainer_spec; // Check default values. EXPECT_EQ(0, trainer_spec.unk_id()); EXPECT_EQ(1, trainer_spec.bos_id()); EXPECT_EQ(2, trainer_spec.eos_id()); EXPECT_EQ(-1, trainer_spec.pad_id()); { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(1); trainer_spec.set_eos_id(2); trainer_spec.set_pad_id(3); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_EQ(4, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); EXPECT_EQ("", trainer.meta_pieces_[1].first); EXPECT_EQ("", trainer.meta_pieces_[2].first); EXPECT_EQ("", trainer.meta_pieces_[3].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(3); trainer_spec.set_eos_id(2); trainer_spec.set_pad_id(1); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_EQ(4, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); EXPECT_EQ("", trainer.meta_pieces_[1].first); EXPECT_EQ("", trainer.meta_pieces_[2].first); EXPECT_EQ("", trainer.meta_pieces_[3].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(-1); trainer_spec.set_eos_id(1); trainer_spec.set_pad_id(-1); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_EQ(2, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); EXPECT_EQ("", trainer.meta_pieces_[1].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(-1); trainer_spec.set_eos_id(-1); trainer_spec.set_pad_id(-1); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_EQ(1, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(1); trainer_spec.set_eos_id(2); trainer_spec.set_pad_id(-1); trainer_spec.add_control_symbols(""); trainer_spec.add_control_symbols(""); trainer_spec.add_user_defined_symbols(""); trainer_spec.add_user_defined_symbols(""); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_EQ(7, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); EXPECT_EQ("", trainer.meta_pieces_[1].first); EXPECT_EQ("", trainer.meta_pieces_[2].first); EXPECT_EQ("", trainer.meta_pieces_[3].first); EXPECT_EQ("", trainer.meta_pieces_[4].first); EXPECT_EQ("", trainer.meta_pieces_[5].first); EXPECT_EQ("", trainer.meta_pieces_[6].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(-1); trainer_spec.set_eos_id(2); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.status().ok()); } { // UNK is not defined. auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(-1); trainer_spec.set_bos_id(0); trainer_spec.set_eos_id(1); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } { // UNK is out-of-range. auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(640000); trainer_spec.set_bos_id(0); trainer_spec.set_eos_id(1); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_vocab_size(32000); trainer_spec.set_unk_id(32000 - 1); trainer_spec.set_bos_id(32000 - 100); trainer_spec.set_eos_id(32000 - 200); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.status().ok()); } { // Cannot assign as control symbol. auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(1); trainer_spec.set_eos_id(2); trainer_spec.add_control_symbols(""); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } { // Dup. auto trainer_spec = base_trainer_spec; trainer_spec.add_control_symbols(""); trainer_spec.add_control_symbols(""); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_id(0); trainer_spec.set_bos_id(10); trainer_spec.set_eos_id(20); trainer_spec.set_pad_id(30); // , are treated as USER_DEFIEND, // is CONTROL. trainer_spec.add_user_defined_symbols(""); trainer_spec.add_user_defined_symbols(""); trainer_spec.add_user_defined_symbols("foo"); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.status().ok()); EXPECT_EQ(5, trainer.meta_pieces_.size()); EXPECT_EQ("", trainer.meta_pieces_[0].first); EXPECT_EQ("", trainer.meta_pieces_[10].first); EXPECT_EQ("", trainer.meta_pieces_[20].first); EXPECT_EQ("", trainer.meta_pieces_[30].first); EXPECT_EQ("foo", trainer.meta_pieces_[1].first); EXPECT_EQ(ModelProto::SentencePiece::UNKNOWN, trainer.meta_pieces_[0].second); EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED, trainer.meta_pieces_[10].second); EXPECT_EQ(ModelProto::SentencePiece::CONTROL, trainer.meta_pieces_[20].second); EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED, trainer.meta_pieces_[30].second); EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED, trainer.meta_pieces_[1].second); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_piece("__UNK__"); trainer_spec.set_bos_piece("__BOS__"); trainer_spec.set_eos_piece("__EOS__"); trainer_spec.set_pad_piece("__PAD__"); trainer_spec.set_pad_id(3); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.status().ok()); EXPECT_EQ("__UNK__", trainer.meta_pieces_[0].first); EXPECT_EQ("__BOS__", trainer.meta_pieces_[1].first); EXPECT_EQ("__EOS__", trainer.meta_pieces_[2].first); EXPECT_EQ("__PAD__", trainer.meta_pieces_[3].first); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_piece("__UNK__"); trainer_spec.set_bos_piece("__UNK__"); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } { auto trainer_spec = base_trainer_spec; trainer_spec.set_unk_piece(""); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_FALSE(trainer.status().ok()); } } TEST(TrainerInterfaceTest, BytePiecesTest) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; trainer_spec.set_model_prefix("model"); trainer_spec.add_input("input"); trainer_spec.add_control_symbols(""); trainer_spec.add_control_symbols(""); trainer_spec.add_user_defined_symbols(""); trainer_spec.add_user_defined_symbols(""); trainer_spec.set_byte_fallback(true); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_TRUE(trainer.status().ok()); // Byte pieces come after control symbols and user-defined symbols. for (int i = 0; i < 256; ++i) { const auto &piece = trainer.meta_pieces_[i + 7]; EXPECT_EQ(absl::StrFormat("<0x%02X>", i), piece.first); EXPECT_EQ(ModelProto::SentencePiece::BYTE, piece.second); } } TEST(TrainerInterfaceTest, SerializeTest) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; trainer_spec.set_model_prefix("model"); trainer_spec.add_input("input"); EXPECT_TRUE(trainer_spec.hard_vocab_limit()); std::vector> final_pieces = { {"a", 0.1}, {"b", 0.2}, {"c", 0.3}}; { trainer_spec.set_vocab_size(10); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); trainer.final_pieces_ = final_pieces; ModelProto model_proto; EXPECT_FALSE(trainer.Serialize(&model_proto).ok()); } { trainer_spec.set_vocab_size(10); trainer_spec.set_hard_vocab_limit(false); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); trainer.final_pieces_ = final_pieces; ModelProto model_proto; EXPECT_TRUE(trainer.Serialize(&model_proto).ok()); EXPECT_EQ(6, model_proto.trainer_spec().vocab_size()); for (int i = 3; i < 6; ++i) { EXPECT_EQ(final_pieces[i - 3].first, model_proto.pieces(i).piece()); EXPECT_EQ(final_pieces[i - 3].second, model_proto.pieces(i).score()); } } { trainer_spec.set_vocab_size(10); trainer_spec.set_model_type(TrainerSpec::CHAR); trainer_spec.set_hard_vocab_limit(true); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); trainer.final_pieces_ = final_pieces; ModelProto model_proto; EXPECT_TRUE(trainer.Serialize(&model_proto).ok()); EXPECT_EQ(6, model_proto.trainer_spec().vocab_size()); for (int i = 3; i < 6; ++i) { EXPECT_EQ(final_pieces[i - 3].first, model_proto.pieces(i).piece()); EXPECT_EQ(final_pieces[i - 3].second, model_proto.pieces(i).score()); } } } TEST(TrainerInterfaceTest, CharactersTest) { const std::string input_file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input"); { auto output = filesystem::NewWritableFile(input_file); // Make a single line with 50 "a", 49 "あ", and 1 "b". std::string line; for (int i = 0; i < 100; i++) { if (i < 50) { line += "a"; } else if (i < 99) { line += "あ"; } else { line += "b"; } } line += "\n"; output->WriteLine(line); } TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; trainer_spec.add_input(input_file); trainer_spec.set_model_prefix("model"); trainer_spec.set_character_coverage(0.98); using E = absl::flat_hash_map; { TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_OK(trainer.LoadSentences()); // Because --character_coverage=0.98, "a" and "あ" are chosen, but "b" is // dropped. EXPECT_EQ(trainer.required_chars_, E({{ToChar32("a"), 50}, {ToChar32("あ"), 49}})); } { trainer_spec.set_required_chars("漢字"); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_OK(trainer.LoadSentences()); // 漢 and 字 do not occur in the line, but they are added. EXPECT_EQ(trainer.required_chars_, E({{ToChar32("a"), 50}, {ToChar32("あ"), 49}, {ToChar32("漢"), 0}, {ToChar32("字"), 0}})); } { trainer_spec.set_required_chars("aあ"); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_OK(trainer.LoadSentences()); // Adding characters that frequently occur do not change the result. EXPECT_EQ(trainer.required_chars_, E({{ToChar32("a"), 50}, {ToChar32("あ"), 49}})); } { trainer_spec.set_required_chars("b"); TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_OK(trainer.LoadSentences()); // "b" is added with the correct frequency. EXPECT_EQ( trainer.required_chars_, E({{ToChar32("a"), 50}, {ToChar32("あ"), 49}, {ToChar32("b"), 1}})); } } TEST(TrainerInterfaceTest, MultiFileSentenceIteratorTest) { std::vector files; std::vector expected; for (int i = 0; i < 10; ++i) { const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), absl::StrCat("input", i)); auto output = filesystem::NewWritableFile(file); int num_line = (rand() % 100) + 1; for (int n = 0; n < num_line; ++n) { const auto value = absl::StrCat(rand()); expected.emplace_back(value); output->WriteLine(value); } files.push_back(file); } std::vector results; MultiFileSentenceIterator it(files); for (; !it.done(); it.Next()) results.emplace_back(it.value()); EXPECT_OK(it.status()); EXPECT_EQ(expected, results); } TEST(TrainerInterfaceTest, MultiFileSentenceIteratorErrorTest) { std::vector files; for (int i = 0; i < 10; ++i) { const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), absl::StrCat("input_not_exist", i)); files.push_back(file); } MultiFileSentenceIterator it(files); EXPECT_TRUE(it.done()); // no files can be loaded. EXPECT_FALSE(it.status().ok()); } } // namespace sentencepiece