From d48247191a6d50e469ed1a4a36e877befffd1851 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Wed, 20 May 2020 13:45:49 +0900 Subject: 0.1.91 pre-release --- src/bpe_model.h | 4 ++++ src/builder.cc | 9 +-------- src/compile_charsmap_main.cc | 6 +++--- src/model_interface.h | 6 ++++++ src/sentencepiece_processor.cc | 20 +++++++++++--------- src/sentencepiece_processor.h | 12 ++++-------- src/sentencepiece_processor_test.cc | 4 ++++ src/sentencepiece_trainer.cc | 33 +++++++++++++++++++-------------- src/sentencepiece_trainer.h | 15 ++++++++------- src/sentencepiece_trainer_test.cc | 3 +-- src/trainer_interface.cc | 18 +++++++----------- src/trainer_interface.h | 14 +++++++------- src/unigram_model.cc | 4 ++-- src/unigram_model.h | 4 ++++ 14 files changed, 81 insertions(+), 71 deletions(-) (limited to 'src') diff --git a/src/bpe_model.h b/src/bpe_model.h index 243664f..c6e1abe 100644 --- a/src/bpe_model.h +++ b/src/bpe_model.h @@ -42,6 +42,10 @@ class Model : public ModelInterface { // When alpha <= 0.0, no sampling is performed. EncodeResult SampleEncode(absl::string_view normalized, float alpha) const override; + + bool IsSampleEncodeAvailable() const override { return true; } + + bool IsNBestEncodeAvailable() const override { return false; } }; } // namespace bpe } // namespace sentencepiece diff --git a/src/builder.cc b/src/builder.cc index 7e8ca98..d9442d3 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -54,14 +54,7 @@ Builder::Chars UnicodeNormalize(UNormalizationMode mode, const std::string utf8 = string_util::UnicodeTextToUTF8(input); CHECK(!utf8.empty()); - icu::UnicodeString ustr; - const size_t utf8_length = utf8.size(); - UChar *utf16 = ustr.getBuffer(utf8.size() + 1); - int32 utf16_length = 0; - icu::ErrorCode icuerrorcode; - u_strFromUTF8Lenient(utf16, ustr.getCapacity(), &utf16_length, utf8.data(), - utf8_length, icuerrorcode); - ustr.releaseBuffer(utf16_length); + icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(utf8.c_str()); UErrorCode status = U_ZERO_ERROR; icu::UnicodeString dst; diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc index 21f1ee8..e8fc072 100644 --- a/src/compile_charsmap_main.cc +++ b/src/compile_charsmap_main.cc @@ -25,7 +25,6 @@ #include "third_party/absl/strings/string_view.h" using sentencepiece::normalizer::Builder; -using util::Status; DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file"); @@ -157,8 +156,9 @@ struct BinaryBlob { int main(int argc, char **argv) { sentencepiece::flags::ParseCommandLineFlags(argv[0], &argc, &argv, true); - const std::vector< - std::pair>> + const std::vector>> kRuleList = {{"nfkc", Builder::BuildNFKCMap}, {"nmt_nfkc", Builder::BuildNmtNFKCMap}, {"nfkc_cf", Builder::BuildNFKC_CFMap}, diff --git a/src/model_interface.h b/src/model_interface.h index 98a4798..27dad99 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -106,6 +106,12 @@ class ModelInterface { return EncodeResult(); } + // Return true if SampleEncode returns a valid result. + virtual bool IsSampleEncodeAvailable() const { return false; } + + // Return true if NBestEncode returns a valid result. + virtual bool IsNBestEncodeAvailable() const { return false; } + // Returns the vocab id of `piece`. // Returns UNK(0) if `piece` is unknown virtual int PieceToId(absl::string_view piece) const; diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 4263a2f..a4dd575 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_processor.h" - #include #include #include @@ -24,6 +22,7 @@ #include "model_factory.h" #include "model_interface.h" #include "normalizer.h" +#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -446,6 +445,9 @@ util::Status SentencePieceProcessor::NBestEncode( std::vector norm_to_orig; RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); + CHECK_OR_RETURN(model_->IsNBestEncodeAvailable()) + << "NBestEncode is not available for the current model."; + const auto nbests = model_->NBestEncode(normalized, nbest_size); CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result."; @@ -470,7 +472,13 @@ util::Status SentencePieceProcessor::SampleEncode( std::vector norm_to_orig; RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - if (nbest_size == 1 || nbest_size == 0) { + if (!model_->IsNBestEncodeAvailable() || nbest_size < 0) { + CHECK_OR_RETURN(model_->IsSampleEncodeAvailable()) + << "SampleEncode is not available for the current model."; + const auto result = model_->SampleEncode(normalized, alpha); + RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, + result, spt)); + } else if (nbest_size == 1 || nbest_size == 0) { const auto result = model_->Encode(normalized); RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt)); @@ -487,11 +495,6 @@ util::Status SentencePieceProcessor::SampleEncode( std::discrete_distribution dist(probs.begin(), probs.end()); RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, nbests[dist(*mt)].first, spt)); - - } else if (nbest_size < 0) { - const auto result = model_->SampleEncode(normalized, alpha); - RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, - result, spt)); } return util::OkStatus(); @@ -828,6 +831,5 @@ util::Status SaveModelProto(absl::string_view filename, return util::OkStatus(); } - } // namespace io } // namespace sentencepiece diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 2b31cb1..019eddf 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -286,9 +286,8 @@ class SentencePieceProcessor { // // - BPE (--model_type=bpe): // `alpha` is the merge probability `p` in https://arxiv.org/abs/1910.13267 - // when alpha<=0, no sampling is performed but the best segmentation is - // returned. Nbest-based sampling is not supported so you need to specify - // nbest_size = 0 in BPE. + // Nbest-based sampling is not supported so nbest_size parameter is ignored in + // BPE. virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, std::vector *pieces) const; @@ -503,13 +502,10 @@ namespace io { // io::LoadModelProto("//path/spm.model", model_proto.get()); // SentencePieceProcessor sp; // CHECK_OK(sp.Load(std::move(model_proto))); -util::Status LoadModelProto(absl::string_view filename, - ModelProto *model_proto); +util::Status LoadModelProto(absl::string_view, ModelProto *model_proto); // Saves `model_proto` as `filename`. -util::Status SaveModelProto(absl::string_view filename, - const ModelProto &model_proto); - +util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto); } // namespace io #endif // SWIG } // namespace sentencepiece diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 3e00404..bceba2c 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -63,6 +63,10 @@ class MockModel : public ModelInterface { return nbest_output_; } + bool IsSampleEncodeAvailable() const override { return true; } + + bool IsNBestEncodeAvailable() const override { return true; } + bool IsControl(int id) const { return id == 1 || id == 2; } bool IsUnknown(int id) const { return id == 0; } diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index 10c5b6f..e36aa9c 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_trainer.h" - #include #include @@ -22,6 +20,7 @@ #include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "normalizer.h" +#include "sentencepiece_trainer.h" #include "spec_parser.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_split.h" @@ -75,10 +74,15 @@ util::Status SentencePieceTrainer::Train( LOG(INFO) << "Starts training with : \n" << info; - trainer->SetSentenceIterator(sentence_iterator); - trainer->SetOutputSerializedModelProto(serialized_model_proto); + if (serialized_model_proto) { + ModelProto model_proto; + RETURN_IF_ERROR(trainer->Train(sentence_iterator, &model_proto)); + *serialized_model_proto = model_proto.SerializeAsString(); + } else { + RETURN_IF_ERROR(trainer->Train(sentence_iterator, nullptr)); + } - return trainer->Train(); + return util::OkStatus(); } // static @@ -100,7 +104,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( if (args.empty()) return util::OkStatus(); - std::map kwargs; + std::unordered_map kwargs; for (auto arg : absl::StrSplit(args, " ")) { absl::ConsumePrefix(&arg, "--"); std::string key, value; @@ -120,8 +124,9 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( // static util::Status SentencePieceTrainer::MergeSpecsFromArgs( - const std::map &kwargs, TrainerSpec *trainer_spec, - NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) { + const std::unordered_map &kwargs, + TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, + NormalizerSpec *denormalizer_spec) { CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null."; CHECK_OR_RETURN(normalizer_spec) << "`normalizer_spec` must not be null."; CHECK_OR_RETURN(denormalizer_spec) << "`denormalizer_spec` must not be null."; @@ -174,7 +179,7 @@ util::Status SentencePieceTrainer::Train(absl::string_view args, // static util::Status SentencePieceTrainer::Train( - const std::map &kwargs, + const std::unordered_map &kwargs, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; @@ -216,11 +221,11 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec( // static util::Status SentencePieceTrainer::PopulateModelTypeFromString( absl::string_view type, TrainerSpec *spec) { - static const std::map kModelTypeMap = { - {"unigram", TrainerSpec::UNIGRAM}, - {"bpe", TrainerSpec::BPE}, - {"word", TrainerSpec::WORD}, - {"char", TrainerSpec::CHAR}}; + static const std::unordered_map + kModelTypeMap = {{"unigram", TrainerSpec::UNIGRAM}, + {"bpe", TrainerSpec::BPE}, + {"word", TrainerSpec::WORD}, + {"char", TrainerSpec::CHAR}}; const auto it = kModelTypeMap.find(absl::AsciiStrToLower(type)); if (it != kModelTypeMap.end()) { spec->set_model_type(it->second); diff --git a/src/sentencepiece_trainer.h b/src/sentencepiece_trainer.h index 5782741..bb74ab9 100644 --- a/src/sentencepiece_trainer.h +++ b/src/sentencepiece_trainer.h @@ -15,8 +15,8 @@ #ifndef SENTENCEPIECE_TRAINER_H_ #define SENTENCEPIECE_TRAINER_H_ -#include #include +#include #include "sentencepiece_processor.h" @@ -84,9 +84,10 @@ class SentencePieceTrainer { // Trains SentencePiece model with mapin `kwargs`. // e.g., {{"input", "data"}, {"model_prefix, "m"}, {"vocab_size", "8192"}...} - static util::Status Train(const std::map &kwargs, - SentenceIterator *sentence_iterator = nullptr, - std::string *serialized_model_proto = nullptr); + static util::Status Train( + const std::unordered_map &kwargs, + SentenceIterator *sentence_iterator = nullptr, + std::string *serialized_model_proto = nullptr); // Handy function to make a normalizer spec from the pre-compiled // normalization name. Do not use this method in production as it crashes @@ -96,12 +97,12 @@ class SentencePieceTrainer { // Populates necessary fields (precompiled_charmap) from // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`. static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec, - bool is_denomalizer = false); + bool is_denormalizer = false); // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the - // std::map in `kargs`. + // std::unordered_map in `kargs`. static util::Status MergeSpecsFromArgs( - const std::map &kwargs, + const std::unordered_map &kwargs, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec); diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc index c95f686..b78b1d2 100644 --- a/src/sentencepiece_trainer_test.cc +++ b/src/sentencepiece_trainer_test.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_trainer.h" - #include "builtin_pb/sentencepiece_model.pb.h" #include "filesystem.h" +#include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/strings/str_cat.h" #include "util.h" diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 37f7003..5cdb300 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "trainer_interface.h" - #include #include #include @@ -34,6 +32,7 @@ #include "third_party/absl/strings/str_format.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" +#include "trainer_interface.h" #include "unicode_script.h" #include "util.h" @@ -50,7 +49,6 @@ const char TrainerInterface::kUPPBoundaryStr[] = "\t"; namespace { util::Status VerifySpec(const TrainerSpec &trainer_spec) { - // CHECK_OR_RETURN(!trainer_spec.model_prefix().empty()); CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0); if (trainer_spec.model_type() == TrainerSpec::UNIGRAM || @@ -313,10 +311,10 @@ util::Status TrainerInterface::LoadSentences() { (sentence_iterator_ == nullptr && !trainer_spec_.input().empty())) << "SentenceIterator and trainer_spec.input() must be exclusive."; - CHECK_OR_RETURN((serialized_model_proto_ != nullptr && - trainer_spec_.model_prefix().empty()) || - (serialized_model_proto_ == nullptr && - !trainer_spec_.model_prefix().empty())) + CHECK_OR_RETURN( + (output_model_proto_ != nullptr && + trainer_spec_.model_prefix().empty()) || + (output_model_proto_ == nullptr && !trainer_spec_.model_prefix().empty())) << "ModelProto and trainer_spec.model_prefix() must be exclusive."; const bool is_tsv = trainer_spec_.input_format() == "tsv"; @@ -647,10 +645,8 @@ util::Status TrainerInterface::SaveVocab(absl::string_view filename) const { } util::Status TrainerInterface::Save() const { - if (serialized_model_proto_) { - ModelProto model_proto; - RETURN_IF_ERROR(Serialize(&model_proto)); - *serialized_model_proto_ = model_proto.SerializeAsString(); + if (output_model_proto_) { + RETURN_IF_ERROR(Serialize(output_model_proto_)); } else { RETURN_IF_ERROR(SaveModel(trainer_spec_.model_prefix() + ".model")); RETURN_IF_ERROR(SaveVocab(trainer_spec_.model_prefix() + ".vocab")); diff --git a/src/trainer_interface.h b/src/trainer_interface.h index 6cd2469..552b206 100644 --- a/src/trainer_interface.h +++ b/src/trainer_interface.h @@ -88,13 +88,13 @@ class TrainerInterface { virtual ~TrainerInterface(); - virtual void SetSentenceIterator(SentenceIterator *sentence_iterator) { + // Loads sentence from `sentence_iterator` and stores the model + // to `output_model_proto`. + virtual util::Status Train(SentenceIterator *sentence_iterator, + ModelProto *output_model_proto) { sentence_iterator_ = sentence_iterator; - } - - virtual void SetOutputSerializedModelProto( - std::string *serialized_model_proto) { - serialized_model_proto_ = serialized_model_proto; + output_model_proto_ = output_model_proto; + return Train(); } virtual util::Status Train() { return status(); } @@ -158,7 +158,7 @@ class TrainerInterface { SentenceIterator *sentence_iterator_ = nullptr; // Emits model to this proto instead of file. - std::string *serialized_model_proto_ = nullptr; + ModelProto *output_model_proto_ = nullptr; private: // Serialize final_pieces_ to |model_proto|. diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 8f6cd4b..bd2d99b 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -578,7 +578,7 @@ bool Model::VerifyOutputsEquivalent(absl::string_view expected, } else { const int length = p.size(); total_score += IsUserDefinedInlined(id) - ? (length * max_score_ + 1.0) + ? (length * max_score_ - 0.1) : GetScoreInlined(id); } } @@ -688,7 +688,7 @@ EncodeResult Model::EncodeOptimized(absl::string_view normalized) const { const auto length = (key_pos - starts_at); // User defined symbol receives extra bonus to always be selected. const auto score = IsUserDefinedInlined(ret) - ? (length * max_score_ + 1.0) + ? (length * max_score_ - 0.1) : GetScoreInlined(ret); const auto candidate_best_path_score = score + best_path_score_till_here; diff --git a/src/unigram_model.h b/src/unigram_model.h index d67c7c7..df84260 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -127,6 +127,10 @@ class Model : public ModelInterface { EncodeResult SampleEncode(absl::string_view normalized, float theta) const override; + bool IsSampleEncodeAvailable() const override { return true; } + + bool IsNBestEncodeAvailable() const override { return true; } + // Returns the minimum score in sentence pieces. // min_score() - 10 is used for the cost of unknown sentence. float min_score() const { return min_score_; } -- cgit v1.2.3