Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2020-05-10 05:42:37 +0300
committer: Taku Kudo <taku@google.com> 2020-05-10 05:42:37 +0300
commit: 52cc641ac115e1ccc3f6f503ca2b3d532c06caec (patch)
tree: 7b6aadc4dabc298873bea0e92d28dfb1683df4f7 /src
parent: 8b921ac65a0f088618e6679595e655ff331a530f (diff)
Added spec_parser test cases.
Diffstat (limited to 'src')
-rw-r--r-- src/sentencepiece_trainer.cc 12
-rw-r--r-- src/sentencepiece_trainer_test.cc 45
-rw-r--r-- src/spec_parser.h 10
3 files changed, 55 insertions, 12 deletions
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index e110c72..5495422 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -60,10 +60,14 @@ util::Status SentencePieceTrainer::Train(
RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_denormalizer_spec, true));
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec,
copied_denormalizer_spec);
- std::string info = absl::StrCat(PrintProto(trainer_spec),
- PrintProto(copied_normalizer_spec));
- if (!copied_denormalizer_spec.precompiled_charsmap().empty())
- info += PrintProto(copied_denormalizer_spec);
+ std::string info =
+ absl::StrCat(PrintProto(trainer_spec, "trainer_spec"),
+ PrintProto(copied_normalizer_spec, "normalizer_spec"));
+ if (!copied_denormalizer_spec.precompiled_charsmap().empty()) {
+ info += PrintProto(copied_denormalizer_spec, "denormalizer_spec");
+ } else {
+ info += "denormalizer_spec {}";
+ }
LOG(INFO) << "Starts training with : \n" << info;
diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc
index c51bf37..b78b1d2 100644
--- a/src/sentencepiece_trainer_test.cc
+++ b/src/sentencepiece_trainer_test.cc
@@ -36,6 +36,18 @@ void CheckVocab(absl::string_view filename, int expected_vocab_size) {
sp.model_proto().trainer_spec().vocab_size());
}
+void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer,
+ bool expected_has_denormalizer) {
+ SentencePieceProcessor sp;
+ ASSERT_TRUE(sp.Load(filename.data()).ok());
+ const auto &normalizer_spec = sp.model_proto().normalizer_spec();
+ const auto &denormalizer_spec = sp.model_proto().denormalizer_spec();
+ EXPECT_EQ(!normalizer_spec.precompiled_charsmap().empty(),
+ expected_has_normalizer);
+ EXPECT_EQ(!denormalizer_spec.precompiled_charsmap().empty(),
+ expected_has_denormalizer);
+}
+
TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
const std::string input = util::JoinPath(FLAGS_test_srcdir, kTestData);
const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
@@ -119,6 +131,7 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
absl::StrCat("--model_prefix=", model, " --vocab_size=1000"), &it)
.ok());
CheckVocab(model + ".model", 1000);
+ CheckNormalizer(model + ".model", true, false);
}
TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
@@ -126,10 +139,12 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
std::string rule = util::JoinPath(FLAGS_test_srcdir, kNfkcTestData);
const std::string model = util::JoinPath(FLAGS_test_tmpdir, "m");
- SentencePieceTrainer::Train(
- absl::StrCat("--input=", input, " --model_prefix=", model,
- "--vocab_size=1000 ", "--normalization_rule_tsv=", rule))
- .IgnoreError();
+ EXPECT_TRUE(SentencePieceTrainer::Train(
+ absl::StrCat("--input=", input, " --model_prefix=", model,
+ " --vocab_size=1000 ",
+ "--normalization_rule_tsv=", rule))
+ .ok());
+ CheckNormalizer(model + ".model", true, false);
}
TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) {
@@ -147,6 +162,7 @@ TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) {
norm_rule_tsv,
" --denormalization_rule_tsv=", denorm_rule_tsv))
.ok());
+ CheckNormalizer(model + ".model", true, true);
}
TEST(SentencePieceTrainerTest, TrainErrorTest) {
@@ -212,6 +228,15 @@ TEST(SentencePieceTrainerTest, SetProtoFieldTest) {
EXPECT_EQ("bar", spec.input(1));
EXPECT_EQ("buz", spec.input(2));
+ // CSV
+ spec.Clear();
+ ASSERT_TRUE(
+ SentencePieceTrainer::SetProtoField("input", "\"foo,bar\",buz", &spec)
+ .ok());
+ EXPECT_EQ(2, spec.input_size());
+ EXPECT_EQ("foo,bar", spec.input(0));
+ EXPECT_EQ("buz", spec.input(1));
+
ASSERT_TRUE(
SentencePieceTrainer::SetProtoField("model_type", "BPE", &spec).ok());
EXPECT_FALSE(
@@ -278,6 +303,18 @@ TEST(SentencePieceTrainerTest, MergeSpecsFromArgs) {
.ok());
EXPECT_EQ("foo", normalizer_spec.name());
+ ASSERT_TRUE(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--normalization_rule_tsv=foo.tsv", &trainer_spec,
+ &normalizer_spec, &denormalizer_spec)
+ .ok());
+ EXPECT_EQ("foo.tsv", normalizer_spec.normalization_rule_tsv());
+
+ ASSERT_TRUE(SentencePieceTrainer::MergeSpecsFromArgs(
+ "--denormalization_rule_tsv=bar.tsv", &trainer_spec,
+ &normalizer_spec, &denormalizer_spec)
+ .ok());
+ EXPECT_EQ("bar.tsv", denormalizer_spec.normalization_rule_tsv());
+
EXPECT_FALSE(SentencePieceTrainer::MergeSpecsFromArgs(
"--vocab_size=UNK", &trainer_spec, &normalizer_spec,
&denormalizer_spec)
diff --git a/src/spec_parser.h b/src/spec_parser.h
index c4a0620..729e036 100644
--- a/src/spec_parser.h
+++ b/src/spec_parser.h
@@ -100,10 +100,11 @@ namespace sentencepiece {
else \
os << " " << #param_name << ": " << it->second << "\n";
-inline std::string PrintProto(const TrainerSpec &message) {
+inline std::string PrintProto(const TrainerSpec &message,
+ absl::string_view name) {
std::ostringstream os;
- os << "TrainerSpec {\n";
+ os << name << " {\n";
PRINT_REPEATED_STRING(input);
PRINT_PARAM(input_format);
@@ -157,10 +158,11 @@ inline std::string PrintProto(const TrainerSpec &message) {
return os.str();
}
-inline std::string PrintProto(const NormalizerSpec &message) {
+inline std::string PrintProto(const NormalizerSpec &message,
+ absl::string_view name) {
std::ostringstream os;
- os << "NormalizerSpec {\n";
+ os << name << " {\n";
PRINT_PARAM(name);
PRINT_PARAM(add_dummy_prefix);