Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2018-09-02 16:18:57 +0300
committer Taku Kudo <taku@google.com> 2018-09-02 16:18:57 +0300
commit b6a74ee84c6920d5df7976159bc5180e15b6edec (patch)
tree 37dcfa46ee78a015c2acc9928cd29611d83e76ca
parent 4f5223b6457bde8a09d9cf0a8ea964f2636ac0a6 (diff)
Added self testing feature.
-rw-r--r-- src/sentencepiece_model.proto 19
-rw-r--r-- src/sentencepiece_processor.cc 26
-rw-r--r-- src/sentencepiece_trainer_test.cc 3
-rw-r--r-- src/spm_train_main.cc 3
-rw-r--r-- src/trainer_interface.cc 22
-rw-r--r-- src/trainer_interface.h 3
-rw-r--r-- src/util.h 31
7 files changed, 106 insertions, 1 deletions
diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto
index 03681b4..7bbc320 100644
--- a/src/sentencepiece_model.proto
+++ b/src/sentencepiece_model.proto
@@ -56,6 +56,9 @@ message TrainerSpec {
// Since the model is language-agnostic, this field is used as a reference.
repeated string accept_language = 5;
+ // Size of self-test samples, which are encoded in the model file.
+ optional int32 self_test_sample_size = 6 [ default = 0 ];
+
///////////////////////////////////////////////////////////////////
// Training parameters.
//
@@ -191,6 +194,19 @@ message NormalizerSpec {
extensions 200 to max;
}
+// Proto to store samples for self-testing.
+message SelfTestData {
+ message Sample {
+ optional string input = 1;  // Raw sentence handed to Encode().
+ optional string expected = 2;  // Pieces Encode() produced at training time, joined with " ".
+ }
+ repeated Sample samples = 1;  // Re-encoded on model load; any mismatch with `expected` fails the load.
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+}
+
// ModelProto stores model parameters.
// SentencePieceProcessor is supposed to be self-contained.
// All settings/parameters which may change the behavior must be encoded
@@ -224,6 +240,9 @@ message ModelProto {
// Spec for text normalization.
optional NormalizerSpec normalizer_spec = 3;
+ // Stores sample input and its expected segmentation to verify the model.
+ optional SelfTestData self_test_data = 4;
+
// Customized extensions: the range of field numbers
// are open to third-party extensions.
extensions 200 to max;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 8d7139d..9bd41a0 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -80,7 +80,31 @@ util::Status SentencePieceProcessor::Load(
model_ = ModelFactory::Create(*model_proto_);
normalizer_ =
port::MakeUnique<normalizer::Normalizer>(model_proto_->normalizer_spec());
- return status();
+
+ RETURN_IF_ERROR(status());
+
+ // Running self-testing.
+ std::vector<std::string> errors, sps;
+ for (const auto &s : model_proto_->self_test_data().samples()) {
+ RETURN_IF_ERROR(Encode(s.input(), &sps));
+ const std::string result = string_util::Join(sps, " ");
+ if (s.expected() != result) {
+ errors.emplace_back(
+ string_util::StrCat(s.input(), "\t", s.expected(), "\t", result));
+ }
+ }
+
+ if (!errors.empty()) {
+ LOG(INFO) << errors.size() << "/"
+ << model_proto_->self_test_data().samples_size()
+ << " samples did not pass the test.";
+ for (const auto &e : errors) {
+ LOG(INFO) << e;
+ }
+ return util::InternalError("Self-test failures. See LOG(INFO).");
+ }
+
+ return util::OkStatus();
}
util::Status SentencePieceProcessor::SetEncodeExtraOptions(
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index 3fa57b8..d5a6d8d 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -28,6 +28,9 @@ TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
SentencePieceTrainer::Train(string_util::StrCat(
"--input=", input, " --model_prefix=m --vocab_size=1000"));
SentencePieceTrainer::Train(string_util::StrCat(
+ "--input=", input,
+ " --model_prefix=m --vocab_size=1000 --self_test_sample_size=100"));
+ SentencePieceTrainer::Train(string_util::StrCat(
"--input=", input, " --model_prefix=m --vocab_size=1000 ",
"--model_type=bpe"));
SentencePieceTrainer::Train(string_util::StrCat(
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 32e73a0..f786fc6 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -35,6 +35,8 @@ DEFINE_string(model_type, "unigram",
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
DEFINE_string(accept_language, "",
"comma-separated list of languages this model can accept");
+DEFINE_int32(self_test_sample_size, kDefaultTrainerSpec.self_test_sample_size(),
+ "the size of self test samples");
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
"character coverage to determine the minimum symbols");
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
@@ -112,6 +114,7 @@ int main(int argc, char *argv[]) {
SetTrainerSpecFromFlag(input_format);
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
+ SetTrainerSpecFromFlag(self_test_sample_size);
SetTrainerSpecFromFlag(character_coverage);
SetTrainerSpecFromFlag(input_sentence_size);
SetTrainerSpecFromFlag(mining_sentence_size);
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 36313b7..7eee112 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -58,6 +58,7 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
CHECK_RANGE(trainer_spec.num_sub_iterations(), 1, 10);
CHECK_RANGE(trainer_spec.num_threads(), 1, 128);
CHECK_RANGE(trainer_spec.seed_sentencepiece_size(), 1000, 5000000);
+ CHECK_RANGE(trainer_spec.self_test_sample_size(), 0, 1000);
CHECK_RANGE(trainer_spec.shrinking_factor(), 0.5, 0.95);
CHECK_RANGE(trainer_spec.training_sentence_size(), 100, 100000000);
#undef CHECK_RANGE
@@ -156,6 +157,9 @@ util::Status TrainerInterface::LoadSentences() {
for (const auto &it : meta_pieces_) meta_pieces_set.insert(it.second.first);
const PrefixMatcher meta_pieces_matcher(meta_pieces_set);
+ random::ReservoirSampler<std::string> sampler(
+ trainer_spec_.self_test_sample_size());
+
for (const auto &filename : trainer_spec_.input()) {
LOG(INFO) << "Loading corpus: " << filename;
std::string sentence;
@@ -200,6 +204,7 @@ util::Status TrainerInterface::LoadSentences() {
}
sentences_.emplace_back(normalized, freq);
+ sampler.Add(sentence);
if (sentences_.size() ==
static_cast<size_t>(trainer_spec_.input_sentence_size())) {
@@ -209,7 +214,10 @@ util::Status TrainerInterface::LoadSentences() {
}
END:
+ self_test_samples_ = sampler.sampled();
+
LOG(INFO) << "Loaded " << sentences_.size() << " sentences";
+ LOG(INFO) << "Loaded " << self_test_samples_.size() << " test sentences";
// Count character frequencies.
int64 all_chars_count = 0;
@@ -353,6 +361,20 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const {
LOG(INFO) << "Saving model: " << filename;
ModelProto model_proto;
RETURN_IF_ERROR(Serialize(&model_proto));
+
+ // Saves self-testing data.
+ if (!self_test_samples_.empty()) {
+ SentencePieceProcessor sp;
+ RETURN_IF_ERROR(sp.Load(model_proto));
+ for (const auto &input : self_test_samples_) {
+ std::vector<std::string> sps;
+ RETURN_IF_ERROR(sp.Encode(input, &sps));
+ auto *sample = model_proto.mutable_self_test_data()->add_samples();
+ sample->set_input(input);
+ sample->set_expected(string_util::Join(sps, " "));
+ }
+ }
+
auto output = filesystem::NewWritableFile(filename.data(), true);
RETURN_IF_ERROR(output->status());
output->Write(model_proto.SerializeAsString());
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index 5f3b1ae..5909a11 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -131,6 +131,9 @@ class TrainerInterface {
// Initializes `meta_pieces_` from TrainerSpec.
util::Status InitMetaPieces();
+
+ // Randomly sampled raw sentences for self-testing.
+ std::vector<std::string> self_test_samples_;
};
} // namespace sentencepiece
#endif // TRAINER_INTERFACE_H_
diff --git a/src/util.h b/src/util.h
index 4d940cf..05af79a 100644
--- a/src/util.h
+++ b/src/util.h
@@ -403,6 +403,37 @@ namespace random {
std::mt19937 *GetRandomGenerator();
+// Keeps a uniform random sample of at most `size` items from a stream of
+// unknown length (reservoir sampling). After `total` calls to Add(), each
+// offered item is present in sampled() with probability `size / total`.
+// A `size` of 0 disables sampling: Add() is a no-op and sampled() stays empty.
+template <typename T>
+class ReservoirSampler {
+ public:
+  // `size` is the maximum number of items retained. The engine is seeded
+  // from std::random_device, so the chosen sample differs between runs.
+  explicit ReservoirSampler(size_t size)
+      : size_(size), engine_(std::random_device{}()) {}
+  virtual ~ReservoirSampler() {}
+
+  // Offers `item` to the sampler. The first `size` items are always kept;
+  // afterwards the new item overwrites a random slot with probability
+  // `size / total`, where `total` counts every item offered so far.
+  void Add(const T &item) {
+    if (size_ == 0) return;
+
+    ++total_;
+    if (sampled_.size() < size_) {
+      sampled_.push_back(item);
+      return;
+    }
+
+    // Draw the slot as size_t: `total_` is a size_t, and the default
+    // int-typed distribution would narrow the upper bound once more than
+    // INT_MAX items have been offered.
+    std::uniform_int_distribution<size_t> dist(0, total_ - 1);
+    const size_t n = dist(engine_);
+    if (n < sampled_.size()) {
+      sampled_[n] = item;
+    }
+  }
+
+  // Returns the items currently held in the reservoir.
+  const std::vector<T> &sampled() const { return sampled_; }
+
+ private:
+  size_t size_ = 0;        // Capacity of the reservoir.
+  size_t total_ = 0;       // Number of items offered while size_ > 0.
+  std::mt19937 engine_;    // Source of randomness for slot selection.
+  std::vector<T> sampled_; // Current reservoir contents.
+};
+
} // namespace random
namespace util {