From f7bc3dbfb6b9afb4b7323e01b200b83291ee9b34 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Tue, 13 Oct 2020 13:02:56 +0900 Subject: merges internal changes to github --- src/builder.cc | 8 ++++---- src/char_model_trainer.cc | 5 +++-- src/sentencepiece_model.proto | 8 -------- src/sentencepiece_processor.cc | 8 ++++++-- src/sentencepiece_processor_test.cc | 22 ++++++++++++++++++++++ src/trainer_interface.cc | 3 +-- src/unigram_model_trainer.cc | 4 ++++ src/word_model_trainer.cc | 3 ++- 8 files changed, 42 insertions(+), 19 deletions(-) diff --git a/src/builder.cc b/src/builder.cc index d9442d3..2c83645 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -476,7 +476,7 @@ util::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) { // static util::Status Builder::LoadCharsMap(absl::string_view filename, CharsMap *chars_map) { - LOG(INFO) << "Loading maping file: " << filename.data(); + LOG(INFO) << "Loading mapping file: " << filename.data(); CHECK_OR_RETURN(chars_map); auto input = filesystem::NewReadableFile(filename); @@ -487,16 +487,16 @@ util::Status Builder::LoadCharsMap(absl::string_view filename, chars_map->clear(); while (input->ReadLine(&line)) { std::vector fields = - absl::StrSplit(line, "\t", absl::AllowEmpty()); + absl::StrSplit(line, '\t', absl::AllowEmpty()); CHECK_GE(fields.size(), 1); if (fields.size() == 1) fields.push_back(""); // Deletion rule. std::vector src, trg; - for (auto s : absl::StrSplit(fields[0], " ")) { + for (auto s : absl::StrSplit(fields[0], ' ')) { if (s.empty()) continue; absl::ConsumePrefix(&s, "U+"); src.push_back(string_util::HexToInt(s)); } - for (auto s : absl::StrSplit(fields[1], " ")) { + for (auto s : absl::StrSplit(fields[1], ' ')) { if (s.empty()) continue; absl::ConsumePrefix(&s, "U+"); trg.push_back(string_util::HexToInt(s)); diff --git a/src/char_model_trainer.cc b/src/char_model_trainer.cc index b758a8c..f438d78 100644 --- a/src/char_model_trainer.cc +++ b/src/char_model_trainer.cc @@ -45,8 +45,9 @@ util::Status Trainer::Train() { final_pieces_.size() == static_cast(vocab_size)) { break; } - final_pieces_.emplace_back(string_util::UnicodeCharToUTF8(it.first), - std::log(static_cast(it.second)) - logsum); + final_pieces_.emplace_back( + string_util::UnicodeCharToUTF8(it.first), + std::log(static_cast(it.second)) - logsum); } if (trainer_spec_.use_all_vocab()) { diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index fe7bef7..4128d6c 100644 --- a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -19,9 +19,6 @@ option optimize_for = LITE_RUNTIME; package sentencepiece; -// BEGIN GOOGLE-INTERNAL -// LINT.IfChange -// END GOOGLE-INTERNAL // TrainerSpec encodes a various parameters for SentencePiece training. message TrainerSpec { /////////////////////////////////////////////////////////////////// @@ -249,11 +246,6 @@ message NormalizerSpec { // are open to third-party extensions. extensions 200 to max; } -// BEGIN GOOGLE-INTERNAL -// LINT.ThenChange( -// //depot/google3/third_party/sentencepiece/src/spm_train_main.cc, -// //depot/google3/third_party/sentencepiece/src/spec_parser.h) -// END GOOGLE-INTERNAL // Proto to store samples for self-testing. message SelfTestData { diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1e87a80..df053fd 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -362,7 +362,7 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText( CHECK_LE_OR_RETURN(orig_end, input.size()); CHECK_LE_OR_RETURN(orig_begin, orig_end); const auto surface = - absl::ClippedSubstr(input.data(), orig_begin, orig_end - orig_begin); + absl::ClippedSubstr(input, orig_begin, orig_end - orig_begin); if (is_unk && model_->ByteFallbackEnabled()) { // Decomposes an unknown piece into UTF-8 bytes @@ -520,7 +520,11 @@ util::Status SentencePieceProcessor::Decode( } } - if (is_bos_ws) { + if (is_bos_ws && + (!model_proto_ || + (model_proto_ && + (model_proto_->normalizer_spec().add_dummy_prefix() || + model_proto_->normalizer_spec().remove_extra_whitespaces())))) { // Consume if the current position is bos and // piece starts with kSpaceSymbol. absl::ConsumePrefix(&piece, kSpaceSymbol); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index ef54071..571dde4 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -685,6 +685,28 @@ TEST(SentencepieceProcessorTest, DecodeTest) { EXPECT_EQ("ABC DEFG HI", spt.text()); EXPECT_EQ(8, spt.pieces_size()); } + + { + SentencePieceProcessor sp; + auto proto = absl::make_unique(); + proto->mutable_trainer_spec()->set_unk_surface(""); + proto->mutable_normalizer_spec()->set_add_dummy_prefix(false); + proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false); + sp.Load(std::move(proto)).IgnoreError(); + + auto mock = absl::make_unique(); + sp.SetModel(std::move(mock)); + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + sp.SetNormalizer( + absl::make_unique(normalization_spec)); + + SentencePieceText spt; + + EXPECT_TRUE(sp.Decode(input, &spt).ok()); + EXPECT_EQ(" ABC DEFG HI", spt.text()); + EXPECT_EQ(8, spt.pieces_size()); + } } TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) { diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index d340af2..0ea71d3 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "trainer_interface.h" - #include #include #include @@ -34,6 +32,7 @@ #include "third_party/absl/strings/str_format.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" +#include "trainer_interface.h" #include "unicode_script.h" #include "util.h" diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index 5f26771..e5dc8c0 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -121,7 +121,11 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() const { } } + CHECK_LE(array.size(), + static_cast(std::numeric_limits::max())) + << "Input corpus too large, try with train_extremely_large_corpus=true"; const node_int_type n = array.size(); + std::vector SA(n); // suffix array std::vector L(n); // left boundaries of internal node std::vector R(n); // right boundaries of internal node diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc index ae274d9..0b8b062 100644 --- a/src/word_model_trainer.cc +++ b/src/word_model_trainer.cc @@ -58,7 +58,8 @@ util::Status Trainer::Train() { final_pieces_.size() == static_cast(vocab_size)) { break; } - final_pieces_.emplace_back(it.first, std::log(static_cast(it.second)) - logsum); + final_pieces_.emplace_back( + it.first, std::log(static_cast(it.second)) - logsum); } if (trainer_spec_.use_all_vocab()) { -- cgit v1.2.3