diff options
author | Taku Kudo <taku910@users.noreply.github.com> | 2018-09-02 21:47:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-02 21:47:33 +0300 |
commit | 03dad83922588a92c04cca8cb770187034cf969b (patch) | |
tree | 37dcfa46ee78a015c2acc9928cd29611d83e76ca | |
parent | 4f5223b6457bde8a09d9cf0a8ea964f2636ac0a6 (diff) | |
parent | b6a74ee84c6920d5df7976159bc5180e15b6edec (diff) |
Merge pull request #196 from google/sr
Added self testing feature.
-rw-r--r-- | src/sentencepiece_model.proto | 19 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 26 | ||||
-rw-r--r-- | src/sentencepiece_trainer_test.cc | 3 | ||||
-rw-r--r-- | src/spm_train_main.cc | 3 | ||||
-rw-r--r-- | src/trainer_interface.cc | 22 | ||||
-rw-r--r-- | src/trainer_interface.h | 3 | ||||
-rw-r--r-- | src/util.h | 31 |
7 files changed, 106 insertions, 1 deletions
diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto
index 03681b4..7bbc320 100644
--- a/src/sentencepiece_model.proto
+++ b/src/sentencepiece_model.proto
@@ -56,6 +56,9 @@ message TrainerSpec {
   // Since the model is language-agnostic, this field is used as a reference.
   repeated string accept_language = 5;
 
+  // Size of self-test samples, which are encoded in the model file.
+  optional int32 self_test_sample_size = 6 [ default = 0 ];
+
   ///////////////////////////////////////////////////////////////////
   // Training parameters.
   //
@@ -191,6 +194,19 @@ message NormalizerSpec {
   extensions 200 to max;
 }
 
+// Proto to store samples for self-testing.
+message SelfTestData {
+  message Sample {
+    optional string input = 1;
+    optional string expected = 2;
+  }
+  repeated Sample samples = 1;
+
+  // Customized extensions: the range of field numbers
+  // are open to third-party extensions.
+  extensions 200 to max;
+}
+
 // ModelProto stores model parameters.
 // SentencePieceProcessor is supposed to be self-contained.
 // All settings/parameters which may change the behavior must be encoded
@@ -224,6 +240,9 @@ message ModelProto {
   // Spec for text normalization.
   optional NormalizerSpec normalizer_spec = 3;
 
+  // Stores sample input and its expected segmentation to verify the model.
+  optional SelfTestData self_test_data = 4;
+
   // Customized extensions: the range of field numbers
   // are open to third-party extensions.
   extensions 200 to max;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 8d7139d..9bd41a0 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -80,7 +80,31 @@ util::Status SentencePieceProcessor::Load(
   model_ = ModelFactory::Create(*model_proto_);
   normalizer_ = port::MakeUnique<normalizer::Normalizer>(
       model_proto_->normalizer_spec());
-  return status();
+
+  RETURN_IF_ERROR(status());
+
+  // Running self-testing.
+  std::vector<std::string> errors, sps;
+  for (const auto &s : model_proto_->self_test_data().samples()) {
+    RETURN_IF_ERROR(Encode(s.input(), &sps));
+    const std::string result = string_util::Join(sps, " ");
+    if (s.expected() != result) {
+      errors.emplace_back(
+          string_util::StrCat(s.input(), "\t", s.expected(), "\t", result));
+    }
+  }
+
+  if (!errors.empty()) {
+    LOG(INFO) << errors.size() << "/"
+              << model_proto_->self_test_data().samples_size()
+              << " samples did not pass the test.";
+    for (const auto &e : errors) {
+      LOG(INFO) << e;
+    }
+    return util::InternalError("Self-test failures. See LOG(INFO).");
+  }
+
+  return util::OkStatus();
 }
 
 util::Status SentencePieceProcessor::SetEncodeExtraOptions(
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index 3fa57b8..d5a6d8d 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -28,6 +28,9 @@ TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
   SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000"));
   SentencePieceTrainer::Train(string_util::StrCat(
+      "--input=", input,
+      " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100"));
+  SentencePieceTrainer::Train(string_util::StrCat(
       "--input=", input, " --model_prefix=m --vocab_size=1000 ",
       "--model_type=bpe"));
   SentencePieceTrainer::Train(string_util::StrCat(
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 32e73a0..f786fc6 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -35,6 +35,8 @@ DEFINE_string(model_type, "unigram",
 DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
 DEFINE_string(accept_language, "",
               "comma-separated list of languages this model can accept");
+DEFINE_int32(self_test_sample_size, kDefaultTrainerSpec.self_test_sample_size(),
+             "the size of self test samples");
 DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
               "character coverage to determine the minimum symbols");
 DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
@@ -112,6 +114,7 @@ int main(int argc, char *argv[]) {
   SetTrainerSpecFromFlag(input_format);
   SetTrainerSpecFromFlag(model_prefix);
   SetTrainerSpecFromFlag(vocab_size);
+  SetTrainerSpecFromFlag(self_test_sample_size);
   SetTrainerSpecFromFlag(character_coverage);
   SetTrainerSpecFromFlag(input_sentence_size);
   SetTrainerSpecFromFlag(mining_sentence_size);
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 36313b7..7eee112 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -58,6 +58,7 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
   CHECK_RANGE(trainer_spec.num_sub_iterations(), 1, 10);
   CHECK_RANGE(trainer_spec.num_threads(), 1, 128);
   CHECK_RANGE(trainer_spec.seed_sentencepiece_size(), 1000, 5000000);
+  CHECK_RANGE(trainer_spec.self_test_sample_size(), 0, 1000);
   CHECK_RANGE(trainer_spec.shrinking_factor(), 0.5, 0.95);
   CHECK_RANGE(trainer_spec.training_sentence_size(), 100, 100000000);
 #undef CHECK_RANGE
@@ -156,6 +157,9 @@ util::Status TrainerInterface::LoadSentences() {
   for (const auto &it : meta_pieces_) meta_pieces_set.insert(it.second.first);
   const PrefixMatcher meta_pieces_matcher(meta_pieces_set);
 
+  random::ReservoirSampler<std::string> sampler(
+      trainer_spec_.self_test_sample_size());
+
   for (const auto &filename : trainer_spec_.input()) {
     LOG(INFO) << "Loading corpus: " << filename;
     std::string sentence;
@@ -200,6 +204,7 @@ util::Status TrainerInterface::LoadSentences() {
       }
 
       sentences_.emplace_back(normalized, freq);
+      sampler.Add(sentence);
 
       if (sentences_.size() ==
           static_cast<size_t>(trainer_spec_.input_sentence_size())) {
@@ -209,7 +214,10 @@ util::Status TrainerInterface::LoadSentences() {
     }
   }
 END:
+  self_test_samples_ = sampler.sampled();
+
   LOG(INFO) << "Loaded " << sentences_.size() << " sentences";
+  LOG(INFO) << "Loaded " << self_test_samples_.size() << " test sentences";
 
   // Count character frequencies.
   int64 all_chars_count = 0;
@@ -353,6 +361,20 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const {
   LOG(INFO) << "Saving model: " << filename;
   ModelProto model_proto;
   RETURN_IF_ERROR(Serialize(&model_proto));
+
+  // Saves self-testing data.
+  if (!self_test_samples_.empty()) {
+    SentencePieceProcessor sp;
+    RETURN_IF_ERROR(sp.Load(model_proto));
+    for (const auto &input : self_test_samples_) {
+      std::vector<std::string> sps;
+      RETURN_IF_ERROR(sp.Encode(input, &sps));
+      auto *sample = model_proto.mutable_self_test_data()->add_samples();
+      sample->set_input(input);
+      sample->set_expected(string_util::Join(sps, " "));
+    }
+  }
+
   auto output = filesystem::NewWritableFile(filename.data(), true);
   RETURN_IF_ERROR(output->status());
   output->Write(model_proto.SerializeAsString());
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index 5f3b1ae..5909a11 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -131,6 +131,9 @@ class TrainerInterface {
 
   // Initializes `meta_pieces_` from TrainerSpec.
   util::Status InitMetaPieces();
+
+  // Randomly sampled raw sentences for self-testing.
+  std::vector<std::string> self_test_samples_;
 };
 }  // namespace sentencepiece
 #endif  // TRAINER_INTERFACE_H_
diff --git a/src/util.h b/src/util.h
--- a/src/util.h
+++ b/src/util.h
@@ -403,6 +403,37 @@ namespace random {
 
 std::mt19937 *GetRandomGenerator();
 
+template <typename T>
+class ReservoirSampler {
+ public:
+  explicit ReservoirSampler(size_t size)
+      : size_(size), engine_(std::random_device{}()) {}
+  virtual ~ReservoirSampler() {}
+
+  void Add(const T &item) {
+    if (size_ == 0) return;
+
+    ++total_;
+    if (sampled_.size() < size_) {
+      sampled_.push_back(item);
+    } else {
+      std::uniform_int_distribution<> dist(0, total_ - 1);
+      const int n = dist(engine_);
+      if (n < static_cast<int>(sampled_.size())) {
+        sampled_[n] = item;
+      }
+    }
+  }
+
+  const std::vector<T> &sampled() const { return sampled_; }
+
+ private:
+  size_t size_ = 0;
+  size_t total_ = 0;
+  std::mt19937 engine_;
+  std::vector<T> sampled_;
+};
+
 }  // namespace random
 
 namespace util {