diff options
author | Taku Kudo <taku@google.com> | 2020-10-13 07:02:56 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-10-13 07:02:56 +0300 |
commit | f7bc3dbfb6b9afb4b7323e01b200b83291ee9b34 (patch) | |
tree | 7bea5bb344b3dc0b4136a2596ed79165f23f012b | |
parent | 8a08ae5e08350420b1b6c948733d9fbd5c0ed2c2 (diff) |
Merges internal changes to GitHub
-rw-r--r-- | src/builder.cc | 8 | ||||
-rw-r--r-- | src/char_model_trainer.cc | 5 | ||||
-rw-r--r-- | src/sentencepiece_model.proto | 8 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 8 | ||||
-rw-r--r-- | src/sentencepiece_processor_test.cc | 22 | ||||
-rw-r--r-- | src/trainer_interface.cc | 3 | ||||
-rw-r--r-- | src/unigram_model_trainer.cc | 4 | ||||
-rw-r--r-- | src/word_model_trainer.cc | 3 |
8 files changed, 42 insertions, 19 deletions
diff --git a/src/builder.cc b/src/builder.cc index d9442d3..2c83645 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -476,7 +476,7 @@ util::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) { // static util::Status Builder::LoadCharsMap(absl::string_view filename, CharsMap *chars_map) { - LOG(INFO) << "Loading maping file: " << filename.data(); + LOG(INFO) << "Loading mapping file: " << filename.data(); CHECK_OR_RETURN(chars_map); auto input = filesystem::NewReadableFile(filename); @@ -487,16 +487,16 @@ util::Status Builder::LoadCharsMap(absl::string_view filename, chars_map->clear(); while (input->ReadLine(&line)) { std::vector<std::string> fields = - absl::StrSplit(line, "\t", absl::AllowEmpty()); + absl::StrSplit(line, '\t', absl::AllowEmpty()); CHECK_GE(fields.size(), 1); if (fields.size() == 1) fields.push_back(""); // Deletion rule. std::vector<char32> src, trg; - for (auto s : absl::StrSplit(fields[0], " ")) { + for (auto s : absl::StrSplit(fields[0], ' ')) { if (s.empty()) continue; absl::ConsumePrefix(&s, "U+"); src.push_back(string_util::HexToInt<char32>(s)); } - for (auto s : absl::StrSplit(fields[1], " ")) { + for (auto s : absl::StrSplit(fields[1], ' ')) { if (s.empty()) continue; absl::ConsumePrefix(&s, "U+"); trg.push_back(string_util::HexToInt<char32>(s)); diff --git a/src/char_model_trainer.cc b/src/char_model_trainer.cc index b758a8c..f438d78 100644 --- a/src/char_model_trainer.cc +++ b/src/char_model_trainer.cc @@ -45,8 +45,9 @@ util::Status Trainer::Train() { final_pieces_.size() == static_cast<size_t>(vocab_size)) { break; } - final_pieces_.emplace_back(string_util::UnicodeCharToUTF8(it.first), - std::log(static_cast<float>(it.second)) - logsum); + final_pieces_.emplace_back( + string_util::UnicodeCharToUTF8(it.first), + std::log(static_cast<float>(it.second)) - logsum); } if (trainer_spec_.use_all_vocab()) { diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index fe7bef7..4128d6c 100644 --- 
a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -19,9 +19,6 @@ option optimize_for = LITE_RUNTIME; package sentencepiece; -// BEGIN GOOGLE-INTERNAL -// LINT.IfChange -// END GOOGLE-INTERNAL // TrainerSpec encodes a various parameters for SentencePiece training. message TrainerSpec { /////////////////////////////////////////////////////////////////// @@ -249,11 +246,6 @@ message NormalizerSpec { // are open to third-party extensions. extensions 200 to max; } -// BEGIN GOOGLE-INTERNAL -// LINT.ThenChange( -// //depot/google3/third_party/sentencepiece/src/spm_train_main.cc, -// //depot/google3/third_party/sentencepiece/src/spec_parser.h) -// END GOOGLE-INTERNAL // Proto to store samples for self-testing. message SelfTestData { diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1e87a80..df053fd 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -362,7 +362,7 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText( CHECK_LE_OR_RETURN(orig_end, input.size()); CHECK_LE_OR_RETURN(orig_begin, orig_end); const auto surface = - absl::ClippedSubstr(input.data(), orig_begin, orig_end - orig_begin); + absl::ClippedSubstr(input, orig_begin, orig_end - orig_begin); if (is_unk && model_->ByteFallbackEnabled()) { // Decomposes an unknown piece into UTF-8 bytes @@ -520,7 +520,11 @@ util::Status SentencePieceProcessor::Decode( } } - if (is_bos_ws) { + if (is_bos_ws && + (!model_proto_ || + (model_proto_ && + (model_proto_->normalizer_spec().add_dummy_prefix() || + model_proto_->normalizer_spec().remove_extra_whitespaces())))) { // Consume if the current position is bos and // piece starts with kSpaceSymbol. 
absl::ConsumePrefix(&piece, kSpaceSymbol); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index ef54071..571dde4 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -685,6 +685,28 @@ TEST(SentencepieceProcessorTest, DecodeTest) { EXPECT_EQ("ABC<UNK> DEFG HI", spt.text()); EXPECT_EQ(8, spt.pieces_size()); } + + { + SentencePieceProcessor sp; + auto proto = absl::make_unique<ModelProto>(); + proto->mutable_trainer_spec()->set_unk_surface(""); + proto->mutable_normalizer_spec()->set_add_dummy_prefix(false); + proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false); + sp.Load(std::move(proto)).IgnoreError(); + + auto mock = absl::make_unique<DecodeMockModel>(); + sp.SetModel(std::move(mock)); + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + sp.SetNormalizer( + absl::make_unique<normalizer::Normalizer>(normalization_spec)); + + SentencePieceText spt; + + EXPECT_TRUE(sp.Decode(input, &spt).ok()); + EXPECT_EQ(" ABC DEFG HI", spt.text()); + EXPECT_EQ(8, spt.pieces_size()); + } } TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) { diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index d340af2..0ea71d3 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
-#include "trainer_interface.h" - #include <cstdlib> #include <memory> #include <set> @@ -34,6 +32,7 @@ #include "third_party/absl/strings/str_format.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" +#include "trainer_interface.h" #include "unicode_script.h" #include "util.h" diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index 5f26771..e5dc8c0 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -121,7 +121,11 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() const { } } + CHECK_LE(array.size(), + static_cast<size_t>(std::numeric_limits<node_int_type>::max())) + << "Input corpus too large, try with train_extremely_large_corpus=true"; const node_int_type n = array.size(); + std::vector<node_int_type> SA(n); // suffix array std::vector<node_int_type> L(n); // left boundaries of internal node std::vector<node_int_type> R(n); // right boundaries of internal node diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc index ae274d9..0b8b062 100644 --- a/src/word_model_trainer.cc +++ b/src/word_model_trainer.cc @@ -58,7 +58,8 @@ util::Status Trainer::Train() { final_pieces_.size() == static_cast<size_t>(vocab_size)) { break; } - final_pieces_.emplace_back(it.first, std::log(static_cast<float>(it.second)) - logsum); + final_pieces_.emplace_back( + it.first, std::log(static_cast<float>(it.second)) - logsum); } if (trainer_spec_.use_all_vocab()) { |