Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-05-04 22:15:08 +0300
committerMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-05-04 22:15:08 +0300
commitdc3dfe8a6c4932ea037d86f35dfd48b16f291d81 (patch)
tree6a55ca6db939b9102d2e853564bf685697b65ad1
parentee72f10e33c5288577c968e25623e88a81982843 (diff)
merge case mapping with normalization
-rw-r--r--src/builder.cc17
-rw-r--r--src/builder.h3
-rw-r--r--src/normalizer.cc3
-rw-r--r--src/sentencepiece_trainer.cc14
-rw-r--r--src/spec_parser.h1
5 files changed, 38 insertions, 0 deletions
diff --git a/src/builder.cc b/src/builder.cc
index 5a1ea6f..8b544ab 100644
--- a/src/builder.cc
+++ b/src/builder.cc
@@ -518,6 +518,23 @@ util::Status Builder::BuildRecaserMap(Builder::CharsMap *chars_map) {
}
// static
+util::Status Builder::ComposeCharsMaps(const Builder::CharsMap &outer_chars_map, Builder::CharsMap *chars_map, bool add_rest) {
+ for(auto& cp : *chars_map) {
+ auto found = outer_chars_map.find(cp.second);
+ if(found != outer_chars_map.end())
+ cp.second = found->second;
+ }
+ if(add_rest) {
+ for(auto& cp : outer_chars_map) {
+ auto found = chars_map->find(cp.first);
+ if(found == chars_map->end())
+ (*chars_map)[cp.first] = cp.second;
+ }
+ }
+ return util::OkStatus();
+}
+
+// static
util::Status Builder::LoadCharsMap(absl::string_view filename,
CharsMap *chars_map) {
LOG(INFO) << "Loading mapping file: " << filename.data();
diff --git a/src/builder.h b/src/builder.h
index d077230..bbd1063 100644
--- a/src/builder.h
+++ b/src/builder.h
@@ -107,6 +107,9 @@ class Builder {
static util::Status BuildUncaserMap(CharsMap *chars_map);
static util::Status BuildRecaserMap(CharsMap *chars_map);
+ // Create composition outer_chars_map(chars_map) into `chars_map`.
+ static util::Status ComposeCharsMaps(const CharsMap &outer_chars_map, CharsMap *chars_map, bool add_rest);
+
// Builds Chars map save in `filename`.
// Format:
// src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2...
diff --git a/src/normalizer.cc b/src/normalizer.cc
index fb7a0f2..49d6055 100644
--- a/src/normalizer.cc
+++ b/src/normalizer.cc
@@ -47,6 +47,7 @@ Normalizer::~Normalizer() {}
void Normalizer::Init() {
absl::string_view index = spec_->precompiled_charsmap();
+
if (index.empty()) {
LOG(INFO) << "precompiled_charsmap is empty. use identity normalization.";
} else {
@@ -183,6 +184,8 @@ util::Status Normalizer::Normalize(absl::string_view input,
// Adds a space symbol as a suffix (default is false)
if (treat_whitespace_as_suffix_ && spec_->add_dummy_prefix()) add_ws();
+ LOG(INFO) << *normalized;
+
norm_to_orig->push_back(consumed);
CHECK_EQ_OR_RETURN(norm_to_orig->size(), normalized->size() + 1);
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index bb4a9c7..ff4716e 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -224,6 +224,20 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec(
}
}
+ if(normalizer_spec->encode_case()) {
+ normalizer::Builder::CharsMap chars_map;
+ if(!normalizer_spec->precompiled_charsmap().empty())
+ normalizer::Builder::DecompileCharsMap(normalizer_spec->precompiled_charsmap(), &chars_map);
+
+ normalizer::Builder::CharsMap uncaser;
+ std::string precompiledUncaser;
+ RETURN_IF_ERROR(normalizer::Builder::GetPrecompiledCharsMap("case_uncaser", &precompiledUncaser));
+ RETURN_IF_ERROR(normalizer::Builder::DecompileCharsMap(absl::string_view(precompiledUncaser), &uncaser));
+
+ RETURN_IF_ERROR(normalizer::Builder::ComposeCharsMaps(uncaser, &chars_map, /*add_rest=*/true));
+ RETURN_IF_ERROR(normalizer::Builder::CompileCharsMap(chars_map, normalizer_spec->mutable_precompiled_charsmap()));
+ }
+
return util::OkStatus();
}
diff --git a/src/spec_parser.h b/src/spec_parser.h
index 6dd054b..4b4fab7 100644
--- a/src/spec_parser.h
+++ b/src/spec_parser.h
@@ -169,6 +169,7 @@ inline std::string PrintProto(const NormalizerSpec &message,
PRINT_PARAM(remove_extra_whitespaces);
PRINT_PARAM(escape_whitespaces);
PRINT_PARAM(normalization_rule_tsv);
+ PRINT_PARAM(encode_case);
os << "}\n";