diff options
author | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-05-04 22:15:08 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-05-04 22:15:08 +0300 |
commit | dc3dfe8a6c4932ea037d86f35dfd48b16f291d81 (patch) | |
tree | 6a55ca6db939b9102d2e853564bf685697b65ad1 | |
parent | ee72f10e33c5288577c968e25623e88a81982843 (diff) |
merge case mapping with normalization
-rw-r--r-- | src/builder.cc | 17 | ||||
-rw-r--r-- | src/builder.h | 3 | ||||
-rw-r--r-- | src/normalizer.cc | 3 | ||||
-rw-r--r-- | src/sentencepiece_trainer.cc | 14 | ||||
-rw-r--r-- | src/spec_parser.h | 1 |
5 files changed, 38 insertions, 0 deletions
diff --git a/src/builder.cc b/src/builder.cc index 5a1ea6f..8b544ab 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -518,6 +518,23 @@ util::Status Builder::BuildRecaserMap(Builder::CharsMap *chars_map) { } // static +util::Status Builder::ComposeCharsMaps(const Builder::CharsMap &outer_chars_map, Builder::CharsMap *chars_map, bool add_rest) { + for(auto& cp : *chars_map) { + auto found = outer_chars_map.find(cp.second); + if(found != outer_chars_map.end()) + cp.second = found->second; + } + if(add_rest) { + for(auto& cp : outer_chars_map) { + auto found = chars_map->find(cp.first); + if(found == chars_map->end()) + (*chars_map)[cp.first] = cp.second; + } + } + return util::OkStatus(); +} + +// static util::Status Builder::LoadCharsMap(absl::string_view filename, CharsMap *chars_map) { LOG(INFO) << "Loading mapping file: " << filename.data(); diff --git a/src/builder.h b/src/builder.h index d077230..bbd1063 100644 --- a/src/builder.h +++ b/src/builder.h @@ -107,6 +107,9 @@ class Builder { static util::Status BuildUncaserMap(CharsMap *chars_map); static util::Status BuildRecaserMap(CharsMap *chars_map); + // Create composition outer_chars_map(chars_map) into `chars_map`. + static util::Status ComposeCharsMaps(const CharsMap &outer_chars_map, CharsMap *chars_map, bool add_rest); + // Builds Chars map save in `filename`. // Format: // src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2... diff --git a/src/normalizer.cc b/src/normalizer.cc index fb7a0f2..49d6055 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -47,6 +47,7 @@ Normalizer::~Normalizer() {} void Normalizer::Init() { absl::string_view index = spec_->precompiled_charsmap(); + if (index.empty()) { LOG(INFO) << "precompiled_charsmap is empty. use identity normalization."; } else { @@ -183,6 +184,8 @@ util::Status Normalizer::Normalize(absl::string_view input, // Adds a space symbol as a suffix (default is false) if (treat_whitespace_as_suffix_ && spec_->add_dummy_prefix()) add_ws(); + LOG(INFO) << *normalized; + norm_to_orig->push_back(consumed); CHECK_EQ_OR_RETURN(norm_to_orig->size(), normalized->size() + 1); diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index bb4a9c7..ff4716e 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -224,6 +224,20 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec( } } + if(normalizer_spec->encode_case()) { + normalizer::Builder::CharsMap chars_map; + if(!normalizer_spec->precompiled_charsmap().empty()) + normalizer::Builder::DecompileCharsMap(normalizer_spec->precompiled_charsmap(), &chars_map); + + normalizer::Builder::CharsMap uncaser; + std::string precompiledUncaser; + RETURN_IF_ERROR(normalizer::Builder::GetPrecompiledCharsMap("case_uncaser", &precompiledUncaser)); + RETURN_IF_ERROR(normalizer::Builder::DecompileCharsMap(absl::string_view(precompiledUncaser), &uncaser)); + + RETURN_IF_ERROR(normalizer::Builder::ComposeCharsMaps(uncaser, &chars_map, /*add_rest=*/true)); + RETURN_IF_ERROR(normalizer::Builder::CompileCharsMap(chars_map, normalizer_spec->mutable_precompiled_charsmap())); + } + return util::OkStatus(); } diff --git a/src/spec_parser.h b/src/spec_parser.h index 6dd054b..4b4fab7 100644 --- a/src/spec_parser.h +++ b/src/spec_parser.h @@ -169,6 +169,7 @@ inline std::string PrintProto(const NormalizerSpec &message, PRINT_PARAM(remove_extra_whitespaces); PRINT_PARAM(escape_whitespaces); PRINT_PARAM(normalization_rule_tsv); + PRINT_PARAM(encode_case); os << "}\n"; |