diff options
author | Taku Kudo <taku@google.com> | 2020-05-10 05:42:37 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-05-10 05:42:37 +0300 |
commit | 52cc641ac115e1ccc3f6f503ca2b3d532c06caec (patch) | |
tree | 7b6aadc4dabc298873bea0e92d28dfb1683df4f7 /src | |
parent | 8b921ac65a0f088618e6679595e655ff331a530f (diff) |
Added spec_parser test cases.
Diffstat (limited to 'src')
-rw-r--r-- | src/sentencepiece_trainer.cc | 12 | ||||
-rw-r--r-- | src/sentencepiece_trainer_test.cc | 45 | ||||
-rw-r--r-- | src/spec_parser.h | 10 |
3 files changed, 55 insertions, 12 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index e110c72..5495422 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -60,10 +60,14 @@ util::Status SentencePieceTrainer::Train( RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_denormalizer_spec, true)); auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec, copied_denormalizer_spec); - std::string info = absl::StrCat(PrintProto(trainer_spec), - PrintProto(copied_normalizer_spec)); - if (!copied_denormalizer_spec.precompiled_charsmap().empty()) - info += PrintProto(copied_denormalizer_spec); + std::string info = + absl::StrCat(PrintProto(trainer_spec, "trainer_spec"), + PrintProto(copied_normalizer_spec, "normalizer_spec")); + if (!copied_denormalizer_spec.precompiled_charsmap().empty()) { + info += PrintProto(copied_denormalizer_spec, "denormalizer_spec"); + } else { + info += "denormalizer_spec {}"; + } LOG(INFO) << "Starts training with : \n" << info; diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc index c51bf37..b78b1d2 100644 --- a/src/sentencepiece_trainer_test.cc +++ b/src/sentencepiece_trainer_test.cc @@ -36,6 +36,18 @@ void CheckVocab(absl::string_view filename, int expected_vocab_size) { sp.model_proto().trainer_spec().vocab_size()); } +void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer, + bool expected_has_denormalizer) { + SentencePieceProcessor sp; + ASSERT_TRUE(sp.Load(filename.data()).ok()); + const auto &normalizer_spec = sp.model_proto().normalizer_spec(); + const auto &denormalizer_spec = sp.model_proto().denormalizer_spec(); + EXPECT_EQ(!normalizer_spec.precompiled_charsmap().empty(), + expected_has_normalizer); + EXPECT_EQ(!denormalizer_spec.precompiled_charsmap().empty(), + expected_has_denormalizer); +} + TEST(SentencePieceTrainerTest, TrainFromArgsTest) { const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestData); const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m"); @@ -119,6 +131,7 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) { absl::StrCat("--model_prefix=", model, " --vocab_size=1000"), &it) .ok()); CheckVocab(model + ".model", 1000); + CheckNormalizer(model + ".model", true, false); } TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) { @@ -126,10 +139,12 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) { std::string rule = util::JoinPath(FLAGS_test_srcdir, kNfkcTestData); const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m"); - SentencePieceTrainer::Train( - absl::StrCat("--input=", input, " --model_prefix=", model, - "--vocab_size=1000 ", "--normalization_rule_tsv=", rule)) - .IgnoreError(); + EXPECT_TRUE(SentencePieceTrainer::Train( + absl::StrCat("--input=", input, " --model_prefix=", model, + " --vocab_size=1000 ", + "--normalization_rule_tsv=", rule)) + .ok()); + CheckNormalizer(model + ".model", true, false); } TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) { @@ -147,6 +162,7 @@ TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) { norm_rule_tsv, " --denormalization_rule_tsv=", denorm_rule_tsv)) .ok()); + CheckNormalizer(model + ".model", true, true); } TEST(SentencePieceTrainerTest, TrainErrorTest) { @@ -212,6 +228,15 @@ TEST(SentencePieceTrainerTest, SetProtoFieldTest) { EXPECT_EQ("bar", spec.input(1)); EXPECT_EQ("buz", spec.input(2)); + // CSV + spec.Clear(); + ASSERT_TRUE( + SentencePieceTrainer::SetProtoField("input", "\"foo,bar\",buz", &spec) + .ok()); + EXPECT_EQ(2, spec.input_size()); + EXPECT_EQ("foo,bar", spec.input(0)); + EXPECT_EQ("buz", spec.input(1)); + ASSERT_TRUE( SentencePieceTrainer::SetProtoField("model_type", "BPE", &spec).ok()); EXPECT_FALSE( @@ -278,6 +303,18 @@ TEST(SentencePieceTrainerTest, MergeSpecsFromArgs) { .ok()); EXPECT_EQ("foo", normalizer_spec.name()); + ASSERT_TRUE(SentencePieceTrainer::MergeSpecsFromArgs( + "--normalization_rule_tsv=foo.tsv", &trainer_spec, + &normalizer_spec, &denormalizer_spec) + .ok()); + EXPECT_EQ("foo.tsv", normalizer_spec.normalization_rule_tsv()); + + ASSERT_TRUE(SentencePieceTrainer::MergeSpecsFromArgs( + "--denormalization_rule_tsv=bar.tsv", &trainer_spec, + &normalizer_spec, &denormalizer_spec) + .ok()); + EXPECT_EQ("bar.tsv", denormalizer_spec.normalization_rule_tsv()); + EXPECT_FALSE(SentencePieceTrainer::MergeSpecsFromArgs( "--vocab_size=UNK", &trainer_spec, &normalizer_spec, &denormalizer_spec) diff --git a/src/spec_parser.h b/src/spec_parser.h index c4a0620..729e036 100644 --- a/src/spec_parser.h +++ b/src/spec_parser.h @@ -100,10 +100,11 @@ namespace sentencepiece { else \ os << " " << #param_name << ": " << it->second << "\n"; -inline std::string PrintProto(const TrainerSpec &message) { +inline std::string PrintProto(const TrainerSpec &message, + absl::string_view name) { std::ostringstream os; - os << "TrainerSpec {\n"; + os << name << " {\n"; PRINT_REPEATED_STRING(input); PRINT_PARAM(input_format); @@ -157,10 +158,11 @@ inline std::string PrintProto(const TrainerSpec &message) { return os.str(); } -inline std::string PrintProto(const NormalizerSpec &message) { +inline std::string PrintProto(const NormalizerSpec &message, + absl::string_view name) { std::ostringstream os; - os << "NormalizerSpec {\n"; + os << name << " {\n"; PRINT_PARAM(name); PRINT_PARAM(add_dummy_prefix); |