Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com>, 2018-05-10 18:40:42 +0300
committer: Taku Kudo <taku@google.com>, 2018-05-10 18:40:42 +0300
commit: 54210ca31e1489950acbaf7cc4f449c38940b643 (patch)
tree: 10e7cd886f4c2d9cc3c9bcb03a32cae936fb8ad3 /src
parent: d9469a14f96150f5bc94a4c159452b2d50618986 (diff)
CHECK to util::Status migration for Builder
Diffstat (limited to 'src')
-rw-r--r-- src/bpe_model_trainer_test.cc 10
-rw-r--r-- src/builder.cc 61
-rw-r--r-- src/builder.h 20
-rw-r--r-- src/builder_test.cc 27
-rw-r--r-- src/char_model_trainer_test.cc 5
-rw-r--r-- src/compile_charsmap_main.cc 5
-rw-r--r-- src/normalizer_test.cc 14
-rw-r--r-- src/sentencepiece_processor_test.cc 7
-rw-r--r-- src/sentencepiece_trainer.cc 30
-rw-r--r-- src/spm_normalize_main.cc 21
-rw-r--r-- src/unigram_model_trainer_test.cc 6
-rw-r--r-- src/word_model_trainer_test.cc 4
12 files changed, 122 insertions, 88 deletions
diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc
index 0610ae5..8e4c675 100644
--- a/src/bpe_model_trainer_test.cc
+++ b/src/bpe_model_trainer_test.cc
@@ -44,8 +44,10 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
trainer_spec.set_vocab_size(size - 3); // remove <unk>, <s>, </s>
trainer_spec.set_model_prefix(model_prefix);
- auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("identity");
normalizer_spec.set_add_dummy_prefix(false);
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
@@ -76,10 +78,12 @@ TEST(BPETrainerTest, BasicTest) {
TEST(BPETrainerTest, EndToEndTest) {
TrainerSpec trainer_spec;
- NormalizerSpec normalizer_spec;
- normalizer_spec = normalizer::Builder::GetNormalizerSpec("nfkc");
trainer_spec.add_input("../data/wagahaiwa_nekodearu.txt");
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("nfkc");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
+
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);
trainer_spec.set_model_type(TrainerSpec::BPE);
diff --git a/src/builder.cc b/src/builder.cc
index 22e4e66..c94832b 100644
--- a/src/builder.cc
+++ b/src/builder.cc
@@ -33,6 +33,7 @@
namespace sentencepiece {
namespace normalizer {
namespace {
+static constexpr char kDefaultNormalizerName[] = "nfkc";
#ifdef ENABLE_NFKC_COMPILE
// Normalize |input| with ICU's normalizer with |mode|.
@@ -140,8 +141,10 @@ Builder::Chars Normalize(const Builder::CharsMap &chars_map,
} // namespace
// static
-std::string Builder::CompileCharsMap(const CharsMap &chars_map) {
- CHECK(!chars_map.empty());
+util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
+ std::string *output) {
+ CHECK_OR_RETURN(output);
+ CHECK_OR_RETURN(!chars_map.empty());
LOG(INFO) << "Loading CharsMap of size " << chars_map.size();
@@ -175,9 +178,8 @@ std::string Builder::CompileCharsMap(const CharsMap &chars_map) {
}
Darts::DoubleArray trie;
- CHECK_EQ(
- 0,
- trie.build(key.size(), const_cast<char **>(&key[0]), nullptr, &value[0]))
+ CHECK_EQ_OR_RETURN(0, trie.build(key.size(), const_cast<char **>(&key[0]),
+ nullptr, &value[0]))
<< "cannot build double-array";
int max_nodes_size = 0;
@@ -188,41 +190,60 @@ std::string Builder::CompileCharsMap(const CharsMap &chars_map) {
results.size(), strlen(str));
max_nodes_size = std::max(num_nodes, max_nodes_size);
}
- CHECK_LT(max_nodes_size, Normalizer::kMaxTrieResultsSize)
+ CHECK_LT_OR_RETURN(max_nodes_size, Normalizer::kMaxTrieResultsSize)
<< "This charmaps contain many shared prefix. "
<< "The number of shared prefix must be less than "
<< Normalizer::kMaxTrieResultsSize;
StringPiece trie_blob(static_cast<const char *>(trie.array()),
trie.size() * trie.unit_size());
- const std::string blob =
- Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);
+ *output = Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);
- LOG(INFO) << "Generated normalizer blob. size= " << blob.size();
+ LOG(INFO) << "Generated normalizer blob. size= " << output->size();
- return blob;
+ return util::OkStatus();
}
// static
-std::string Builder::GetPrecompiledCharsMap(const std::string &name) {
+util::Status Builder::GetPrecompiledCharsMap(const std::string &name,
+ std::string *output) {
std::string result;
for (size_t i = 0; i < kNormalizationRules_size; ++i) {
const auto *blob = &kNormalizationRules_blob[i];
if (blob->name == name) {
- result.assign(blob->data, blob->size);
- return result;
+ output->assign(blob->data, blob->size);
+ return util::OkStatus();
}
}
- LOG(FATAL) << "No precompiled charsmap is found: " << name;
- return result;
+ return util::StatusBuilder(util::error::NOT_FOUND)
+ << "No precompiled charsmap is found: " << name;
}
// static
-NormalizerSpec Builder::GetNormalizerSpec(const std::string &name) {
- NormalizerSpec spec;
- spec.set_name(name);
- spec.set_precompiled_charsmap(GetPrecompiledCharsMap(name));
- return spec;
+util::Status Builder::PopulateNormalizationSpec(
+ NormalizerSpec *normalizer_spec) {
+ CHECK_OR_RETURN(normalizer_spec);
+
+ if (!normalizer_spec->normalization_rule_tsv().empty()) {
+ CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty())
+ << "precompiled_charsmap is already defined.";
+ const auto chars_map = normalizer::Builder::BuildMapFromFile(
+ normalizer_spec->normalization_rule_tsv());
+ RETURN_IF_ERROR(CompileCharsMap(
+ chars_map, normalizer_spec->mutable_precompiled_charsmap()));
+ normalizer_spec->set_name("user_defined");
+ } else {
+ if (normalizer_spec->name().empty()) {
+ normalizer_spec->set_name(kDefaultNormalizerName);
+ }
+ if (normalizer_spec->precompiled_charsmap().empty()) {
+ RETURN_IF_ERROR(GetPrecompiledCharsMap(
+ normalizer_spec->name(),
+ normalizer_spec->mutable_precompiled_charsmap()));
+ }
+ }
+
+ return util::OkStatus();
}
// static
diff --git a/src/builder.h b/src/builder.h
index d9ae2bc..6c1d6fe 100644
--- a/src/builder.h
+++ b/src/builder.h
@@ -20,6 +20,7 @@
#include <vector>
#include "common.h"
#include "sentencepiece_model.pb.h"
+#include "sentencepiece_processor.h"
#include "stringpiece.h"
namespace sentencepiece {
@@ -41,14 +42,17 @@ class Builder {
// String-to-string mapping.
using CharsMap = std::map<Chars, Chars>;
- // Compiles |chars_map| into a binary index.
- static std::string CompileCharsMap(const CharsMap &chars_map);
+ static util::Status CompileCharsMap(const CharsMap &chars_map,
+ std::string *output);
- // Returns a pre-compiled binary index with |name|.
- static std::string GetPrecompiledCharsMap(const std::string &name);
+ // Returns a pre-compiled binary index with `name`.
+ static util::Status GetPrecompiledCharsMap(const std::string &name,
+ std::string *output);
- // Returns a normalizer spec with a binary index |name|.
- static NormalizerSpec GetNormalizerSpec(const std::string &name);
+ // Populates necessary fields (precompiled_charmap) from
+ // `name` or `normalization_rule_tsv` fields in `normalizer_spec`.
+ static util::Status PopulateNormalizationSpec(
+ NormalizerSpec *normalizer_spec);
// Makes a normalization mapping based on NFKC.
//
@@ -90,7 +94,7 @@ class Builder {
// Returns identity mapping, which dose not perform any normalization.
static CharsMap BuildIdentityMap();
- // Builds Chars map save in |filename|.
+ // Builds Chars map save in `filename`.
// Format:
// src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2...
// (src|trg)_ucharX must be a hex of UCS4.
@@ -99,7 +103,7 @@ class Builder {
private:
FRIEND_TEST(BuilderTest, RemoveRedundantMapTest);
- // Removes redundant rules from |chars_map|.
+ // Removes redundant rules from `chars_map`.
// When char_maps have "aa" => "bb" and "a" => "b", the first
// rule is not necessary since the second rule can cover the first rule.
static CharsMap RemoveRedundantMap(const CharsMap &chars_map);
diff --git a/src/builder_test.cc b/src/builder_test.cc
index 167e602..48e1aa6 100644
--- a/src/builder_test.cc
+++ b/src/builder_test.cc
@@ -42,8 +42,9 @@ TEST(BuilderTest, RemoveRedundantMapTest) {
}
TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
- EXPECT_DEATH(Builder::GetPrecompiledCharsMap(""));
- EXPECT_DEATH(Builder::GetPrecompiledCharsMap("__UNKNOWN__"));
+ std::string output;
+ EXPECT_NOT_OK(Builder::GetPrecompiledCharsMap("", &output));
+ EXPECT_NOT_OK(Builder::GetPrecompiledCharsMap("__UNKNOWN__", &output));
}
TEST(BuilderTest, BuildIdentityMapTest) {
@@ -62,8 +63,9 @@ TEST(BuilderTest, BuildNFKCMapTest) {
TEST(BuilderTest, GetPrecompiledCharsMapTest) {
{
- const NormalizerSpec spec = Builder::GetNormalizerSpec("nfkc");
-
+ NormalizerSpec spec;
+ spec.set_name("nfkc");
+ EXPECT_OK(Builder::PopulateNormalizationSpec(&spec));
const Normalizer normalizer(spec);
EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
EXPECT_EQ(WS "(株)", normalizer.Normalize("㈱"));
@@ -71,8 +73,9 @@ TEST(BuilderTest, GetPrecompiledCharsMapTest) {
}
{
- const NormalizerSpec spec = Builder::GetNormalizerSpec("identity");
-
+ NormalizerSpec spec;
+ spec.set_name("identity");
+ EXPECT_OK(Builder::PopulateNormalizationSpec(&spec));
const Normalizer normalizer(spec);
EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
EXPECT_EQ(WS "㈱", normalizer.Normalize("㈱"));
@@ -94,7 +97,8 @@ TEST(BuilderTest, CompileCharsMap) {
chars_map[{0x3042, 0x3044, 0x3046}] = {0x0061, 0x0062, 0x0063};
NormalizerSpec spec;
- spec.set_precompiled_charsmap(Builder::CompileCharsMap(chars_map));
+ EXPECT_OK(
+ Builder::CompileCharsMap(chars_map, spec.mutable_precompiled_charsmap()));
spec.set_add_dummy_prefix(false);
const Normalizer normalizer(spec);
@@ -110,8 +114,10 @@ TEST(BuilderTest, CompileCharsMap) {
TEST(BuilderTest, BuildMapFromFileTest) {
const auto cmap = Builder::BuildMapFromFile("../data/nfkc.tsv");
- const auto precompiled = Builder::CompileCharsMap(cmap);
- EXPECT_EQ(Builder::GetPrecompiledCharsMap("nfkc"), precompiled);
+ std::string expected, precompiled;
+ EXPECT_OK(Builder::CompileCharsMap(cmap, &precompiled));
+ EXPECT_OK(Builder::GetPrecompiledCharsMap("nfkc", &expected));
+ EXPECT_EQ(expected, precompiled);
}
TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
@@ -122,7 +128,8 @@ TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
keys.push_back('a');
chars_map[keys] = {'b'};
}
- EXPECT_DEATH(Builder::CompileCharsMap(chars_map));
+ std::string output;
+ EXPECT_NOT_OK(Builder::CompileCharsMap(chars_map, &output));
}
} // namespace normalizer
diff --git a/src/char_model_trainer_test.cc b/src/char_model_trainer_test.cc
index f577748..943dc9a 100644
--- a/src/char_model_trainer_test.cc
+++ b/src/char_model_trainer_test.cc
@@ -43,8 +43,9 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
trainer_spec.set_vocab_size(size);
trainer_spec.set_model_prefix(model_prefix);
- auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
- normalizer_spec.set_add_dummy_prefix(true);
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("identity");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc
index 6668f6b..ed8837a 100644
--- a/src/compile_charsmap_main.cc
+++ b/src/compile_charsmap_main.cc
@@ -76,7 +76,7 @@ std::string ToHexData(StringPiece data) {
return os.str();
}
} // namespace
-} // sentencepiece
+} // namespace sentencepiece
int main(int argc, char **argv) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
@@ -108,7 +108,8 @@ constexpr BinaryBlob kNormalizationRules_blob[] = {)";
for (const auto &p : kRuleList) {
const auto normalized_map = p.second();
- const auto index = Builder::CompileCharsMap(normalized_map);
+ std::string index;
+ CHECK_OK(Builder::CompileCharsMap(normalized_map, &index));
os << "{ \"" << p.first << "\", " << index.size() << ",\n";
os << sentencepiece::ToHexData(index);
os << " },";
diff --git a/src/normalizer_test.cc b/src/normalizer_test.cc
index f91bd7d..ea68074 100644
--- a/src/normalizer_test.cc
+++ b/src/normalizer_test.cc
@@ -23,7 +23,12 @@ namespace {
// Space symbol
#define WS "\xe2\x96\x81"
-NormalizerSpec MakeDefaultSpec() { return Builder::GetNormalizerSpec("nfkc"); }
+NormalizerSpec MakeDefaultSpec() {
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("nfkc");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
+ return normalizer_spec;
+}
} // namespace
TEST(NormalizerTest, NormalizeErrorTest) {
@@ -140,7 +145,8 @@ TEST(NormalizeTest, NomalizeWithSpaceContainedRules) {
AddRule("d", " F G ");
NormalizerSpec spec;
- spec.set_precompiled_charsmap(Builder::CompileCharsMap(charsmap));
+ EXPECT_OK(
+ Builder::CompileCharsMap(charsmap, spec.mutable_precompiled_charsmap()));
// Test default behavior
{
@@ -297,8 +303,8 @@ TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) {
EXPECT_EQ("foo", trie_blob);
EXPECT_EQ("bar", normalized_blob);
- EXPECT_NOT_OK(Normalizer::DecodePrecompiledCharsMap("", &trie_blob,
- &normalized_blob));
+ EXPECT_NOT_OK(
+ Normalizer::DecodePrecompiledCharsMap("", &trie_blob, &normalized_blob));
}
TEST(NormalizerTest, StatusTest) {
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 64e2e53..cc766d5 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -100,7 +100,10 @@ std::vector<std::string> GetSpVec(const SentencePieceText &spt) {
}
NormalizerSpec MakeDefaultNormalizerSpec() {
- return normalizer::Builder::GetNormalizerSpec("nfkc");
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("nfkc");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
+ return normalizer_spec;
}
TEST(SentencepieceProcessorTest, StatusTest) {
@@ -512,7 +515,7 @@ void AddPiece(ModelProto *model_proto, StringPiece piece, float score = 0.0) {
TEST(SentencePieceProcessorTest, LoadInvalidModelTest) {
SentencePieceProcessor sp;
- std::istream* stream = nullptr;
+ std::istream *stream = nullptr;
EXPECT_NOT_OK(sp.Load(stream));
EXPECT_NOT_OK(sp.Load(""));
EXPECT_NOT_OK(sp.Load("__UNKNOWN_FILE__"));
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index c9ac133..7513f4c 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -25,9 +25,6 @@
#include "util.h"
namespace sentencepiece {
-namespace {
-static constexpr char kDefaultNormalizerName[] = "nfkc";
-} // namespace
// static
util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
@@ -39,32 +36,11 @@ util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
util::Status SentencePieceTrainer::Train(
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
auto copied_normalizer_spec = normalizer_spec;
-
- if (!copied_normalizer_spec.normalization_rule_tsv().empty()) {
- CHECK_OR_RETURN(copied_normalizer_spec.precompiled_charsmap().empty())
- << "precompiled_charsmap is already defined.";
-
- const auto chars_map = normalizer::Builder::BuildMapFromFile(
- copied_normalizer_spec.normalization_rule_tsv());
- copied_normalizer_spec.set_precompiled_charsmap(
- normalizer::Builder::CompileCharsMap(chars_map));
- copied_normalizer_spec.set_name("user_defined");
- } else {
- if (copied_normalizer_spec.name().empty()) {
- copied_normalizer_spec.set_name(kDefaultNormalizerName);
- }
-
- if (copied_normalizer_spec.precompiled_charsmap().empty()) {
- *(copied_normalizer_spec.mutable_precompiled_charsmap()) =
- normalizer::Builder::GetNormalizerSpec(copied_normalizer_spec.name())
- .precompiled_charsmap();
- }
- }
+ RETURN_IF_ERROR(
+ normalizer::Builder::PopulateNormalizationSpec(&copied_normalizer_spec));
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
- RETURN_IF_ERROR(trainer->Train());
-
- return util::OkStatus();
+ return trainer->Train();
}
// static
diff --git a/src/spm_normalize_main.cc b/src/spm_normalize_main.cc
index c64b6a2..4267576 100644
--- a/src/spm_normalize_main.cc
+++ b/src/spm_normalize_main.cc
@@ -25,27 +25,34 @@ DEFINE_string(model, "", "Model file name");
DEFINE_bool(use_internal_normalization, false,
"Use NormalizerSpec \"as-is\" to run the normalizer "
"for SentencePiece segmentation");
+DEFINE_string(normalization_rule_name, "",
+ "Normalization rule name. "
+ "Choose from nfkc or identity");
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
DEFINE_bool(remove_extra_whitespaces, true, "Remove extra whitespaces");
DEFINE_string(output, "", "Output filename");
+using sentencepiece::normalizer::Builder;
+
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
sentencepiece::NormalizerSpec spec;
- if (FLAGS_normalization_rule_tsv.empty() && !FLAGS_model.empty()) {
+ if (!FLAGS_model.empty()) {
sentencepiece::SentencePieceProcessor sp;
CHECK_OK(sp.Load(FLAGS_model));
spec = sp.model_proto().normalizer_spec();
- } else if (!FLAGS_normalization_rule_tsv.empty() && FLAGS_model.empty()) {
- const auto cmap = sentencepiece::normalizer::Builder::BuildMapFromFile(
- FLAGS_normalization_rule_tsv);
- spec.set_precompiled_charsmap(
- sentencepiece::normalizer::Builder::CompileCharsMap(cmap));
+ } else if (!FLAGS_normalization_rule_tsv.empty()) {
+ spec.set_normalization_rule_tsv(FLAGS_normalization_rule_tsv);
+ CHECK_OK(Builder::PopulateNormalizationSpec(&spec));
+ } else if (!FLAGS_normalization_rule_name.empty()) {
+ spec.set_name(FLAGS_normalization_rule_name);
+ CHECK_OK(Builder::PopulateNormalizationSpec(&spec));
} else {
- LOG(FATAL) << "Sets --model or normalization_rule_tsv flag";
+ LOG(FATAL) << "Sets --model, normalization_rule_tsv, or "
+ "normalization_rule_name flag.";
}
// Uses the normalizer spec encoded in the model_pb.
diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc
index a88dc71..34caf78 100644
--- a/src/unigram_model_trainer_test.cc
+++ b/src/unigram_model_trainer_test.cc
@@ -35,10 +35,12 @@ TEST(UnigramTrainerTest, TrainerModelTest) {
TEST(UnigramTrainerTest, EndToEndTest) {
TrainerSpec trainer_spec;
- NormalizerSpec normalizer_spec;
- normalizer_spec = normalizer::Builder::GetNormalizerSpec("nfkc");
trainer_spec.add_input("../data/wagahaiwa_nekodearu.txt");
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("identity");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
+
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);
trainer_spec.set_model_type(TrainerSpec::UNIGRAM);
diff --git a/src/word_model_trainer_test.cc b/src/word_model_trainer_test.cc
index 35a6eb5..87b105a 100644
--- a/src/word_model_trainer_test.cc
+++ b/src/word_model_trainer_test.cc
@@ -44,7 +44,9 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
trainer_spec.set_vocab_size(size - 3); // remove <unk>, <s>, </s>
trainer_spec.set_model_prefix(model_prefix);
- auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
+ NormalizerSpec normalizer_spec;
+ normalizer_spec.set_name("identity");
+ EXPECT_OK(normalizer::Builder::PopulateNormalizationSpec(&normalizer_spec));
normalizer_spec.set_add_dummy_prefix(true);
Trainer trainer(trainer_spec, normalizer_spec);