diff options
author | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-05-05 09:39:15 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-05-05 09:39:15 +0300 |
commit | 5c5483a82d8edcaf2d7f9c76e9977f07c4880e4a (patch) | |
tree | cdfa1c75669bc965fe681040941ed28f2704974b | |
parent | 0b1208082fc0694bd51679ad32f1718793ca0e27 (diff) |
rewrite case normalizer
-rw-r--r-- | src/case_encoder.h | 186 | ||||
-rw-r--r-- | src/normalizer.cc | 80 |
2 files changed, 130 insertions, 136 deletions
diff --git a/src/case_encoder.h b/src/case_encoder.h index 2e0a04f..afc0dca 100644 --- a/src/case_encoder.h +++ b/src/case_encoder.h @@ -19,7 +19,7 @@ #include <set> #include <string> #include <utility> -#include <vector> +#include <deque> #include "common.h" #include "third_party/absl/strings/string_view.h" @@ -31,125 +31,141 @@ namespace normalizer { class CaseEncoder { public: virtual ~CaseEncoder() {} - virtual bool encode(const absl::string_view& sp, int n, int src, int consumed) = 0; + virtual void push(const std::pair<absl::string_view, int>& p, bool last) = 0; + virtual bool empty() = 0; + virtual std::pair<absl::string_view, int> pop() = 0; - static std::unique_ptr<CaseEncoder> Create(bool, bool, absl::string_view* /*input*/, std::string* /*normalized*/, std::vector<size_t>* /*norm_to_orig*/); + static std::unique_ptr<CaseEncoder> Create(bool, bool); }; class IdentityCaseEncoder : public CaseEncoder { +private: + std::pair<absl::string_view, int> p_; + bool empty_{true}; + public: IdentityCaseEncoder() {} - bool encode(const absl::string_view& sp, int n, int src, int consumed) { - return true; + void push(const std::pair<absl::string_view, int>& p, bool /*last*/) { + p_ = p; + empty_ = false; + } + + bool empty() { + return empty_; + } + + std::pair<absl::string_view, int> pop() { + empty_ = true; + return p_; } }; class UpperCaseEncoder : public CaseEncoder { -private: - char* last_u_{nullptr}; - size_t last_u_dist_{0}; - std::string* normalized_; - std::vector<size_t> *norm_to_orig_; - -public: - UpperCaseEncoder(std::string* normalized, std::vector<size_t> *norm_to_orig) - : normalized_(normalized), norm_to_orig_(norm_to_orig) {} - - bool encode(const absl::string_view& sp, int n, int src, int consumed) { - if(n != 0) - return true; - - char curChar = sp.data()[0]; - if(curChar == ' ') { - if(last_u_) { - if(last_u_dist_ == 1) - *last_u_ = 'T'; - last_u_ = nullptr; - last_u_dist_ = 0; + std::vector<std::string> buffers_; + std::deque<std::pair<absl::string_view, int>> pieces_; + bool flush_{false}; + size_t countU_{0}; + + void fixUs() { + if(countU_ == 1) { + auto sp = pieces_.front().first; + buffers_.emplace_back(sp.data(), sp.size()); + buffers_.back()[0] = 'T'; + pieces_.front().first = absl::string_view(buffers_.back()); + } else if(countU_ > 1) { + for(int i = 1; i < countU_; ++i) { + auto sp = pieces_[i].first; + pieces_[i].first = absl::string_view(sp.data() + 1, sp.size() - 1); } - return true; } + } - if(!last_u_ && curChar == 'U') { - last_u_ = &(*normalized_)[0] + normalized_->size(); - last_u_dist_++; - } else if(last_u_ && curChar == 'U') { // uppercase sequence, skip over U - last_u_dist_++; - return false; - } else if (last_u_ && curChar != 'U' && last_u_dist_ == 1) { // single uppercase letter - *last_u_ = 'T'; - last_u_ = nullptr; - last_u_dist_ = 0; - } else if (last_u_ && (curChar != 'U' && curChar != 'P') && last_u_dist_ > 1) { - // we had a longer uppercase sequence, hence need to insert 'L' - normalized_->append(1, 'L'); - norm_to_orig_->push_back(consumed); - - last_u_ = nullptr; - last_u_dist_ = 0; +public: + UpperCaseEncoder() {} + + void push(const std::pair<absl::string_view, int>& p, bool last) { + auto sp = p.first; + if(sp.data()[0] == 'U') { + pieces_.push_back(p); + countU_++; + flush_ = false; + } else if(sp.data()[0] == 'P') { + fixUs(); + pieces_.push_back({absl::string_view(sp.data() + 1, sp.size() - 1), p.second}); + countU_ = 0; + flush_ = true; + } else if(sp.data()[0] == ' ') { + fixUs(); + pieces_.push_back(p); + countU_ = 0; + flush_ = true; } else { - last_u_ = nullptr; - last_u_dist_ = 0; + fixUs(); + if(countU_ > 1) { + buffers_.emplace_back("L"); + buffers_.back().append(p.first.data(), p.first.size()); + pieces_.push_back({buffers_.back(), p.second}); + } else { + pieces_.push_back(p); + } + countU_ = 0; + flush_ = true; } - if(curChar == 'P') { - last_u_ = nullptr; - last_u_dist_ = 0; - return false; - } + if(last) + flush_ = true; // flush it all out + } + + bool empty() { + return pieces_.empty() || !flush_; + } - return true; + std::pair<absl::string_view, int> pop() { + auto p = pieces_.front(); + pieces_.pop_front(); + return p; } }; class UpperCaseDecoder : public CaseEncoder { -private: - char* last_u_{nullptr}; - size_t last_u_dist_{0}; - std::string* normalized_; - std::vector<size_t> *norm_to_orig_; - - std::string buffer_; - absl::string_view* input_; + std::vector<std::string> buffers_; + std::deque<std::pair<absl::string_view, int>> pieces_; + bool flush_{false}; public: - UpperCaseDecoder(absl::string_view* input, std::string* normalized, std::vector<size_t> *norm_to_orig) - : normalized_(normalized), norm_to_orig_(norm_to_orig), buffer_(input->data(), input->size()), input_(input) { - *input = absl::string_view(buffer_); - - // if(buffer_[0] == 'T') - // buffer_[0] = 'U'; - } + UpperCaseDecoder() {} - bool encode(const absl::string_view& sp, int n, int src, int consumed) { - if(n != 0) - return true; + void push(const std::pair<absl::string_view, int>& p, bool last) { + auto sp = p.first; - std::cerr << "B: " << n << " " << src << " " << consumed << " " << std::string(input_->data(), input_->size()) << std::endl; + std::cerr << p.first << std::endl; + pieces_.push_back(p); + flush_ = true; - // if(consumed + src < buffer_.size() && buffer_[consumed + src] == 'T') { - // buffer_[consumed + src] = 'U'; - // } - - if(consumed + src < buffer_.size() && input_->data()[0] == 'U') { - buffer_[consumed + src - 1] = 'U'; - *input_ = absl::string_view(input_->data() - 1, input_->size() + 1); - } + if(last) + flush_ = true; // flush it all out + } - std::cerr << "A: " << n << " " << src << " " << consumed << " " << std::string(input_->data(), input_->size()) << std::endl; - return true; + bool empty() { + return pieces_.empty() || !flush_; + } + + std::pair<absl::string_view, int> pop() { + auto p = pieces_.front(); + pieces_.pop_front(); + return p; } }; -std::unique_ptr<CaseEncoder> CaseEncoder::Create(bool encodeCase, bool decodeCase, absl::string_view* input, std::string* normalized, std::vector<size_t>* norm_to_orig) { +std::unique_ptr<CaseEncoder> CaseEncoder::Create(bool encodeCase, bool decodeCase) { + // LOG(INFO) << encodeCase << " " << decodeCase; if(encodeCase && decodeCase) { LOG(ERROR) << "Cannot set both encodeCase=true and decodeCase=true"; return nullptr; } else if(encodeCase) { - return std::unique_ptr<CaseEncoder>(new UpperCaseEncoder(normalized, norm_to_orig)); + return std::unique_ptr<CaseEncoder>(new UpperCaseEncoder()); } else if(decodeCase) { - return std::unique_ptr<CaseEncoder>(new IdentityCaseEncoder()); - // return std::unique_ptr<CaseEncoder>(new UpperCaseDecoder(input, normalized, norm_to_orig)); + return std::unique_ptr<CaseEncoder>(new UpperCaseDecoder()); } else { return std::unique_ptr<CaseEncoder>(new IdentityCaseEncoder()); } diff --git a/src/normalizer.cc b/src/normalizer.cc index ced56c4..c11db4a 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -77,11 +77,6 @@ util::Status Normalizer::Normalize(absl::string_view input, std::string *normalized, std::vector<size_t> *norm_to_orig) const { - std::string buffer(input.data(), input.size()); - input = absl::string_view(buffer); - - std::cerr << "In: " << buffer << std::endl; - norm_to_orig->clear(); normalized->clear(); @@ -139,71 +134,54 @@ util::Status Normalizer::Normalize(absl::string_view input, if (!treat_whitespace_as_suffix_ && spec_->add_dummy_prefix()) add_ws(); - std::unique_ptr<CaseEncoder> case_encoder = CaseEncoder::Create(spec_->encode_case(), spec_->decode_case(), &input, normalized, norm_to_orig); - int correction = 0; - bool shifted = false; + std::unique_ptr<CaseEncoder> case_encoder = CaseEncoder::Create(spec_->encode_case(), spec_->decode_case()); + + //////////////////////////////////////////////////////////////////////////////////// bool is_prev_space = spec_->remove_extra_whitespaces(); while (!input.empty()) { - correction = shifted ? 1 : 0; - + // auto p = case_encoder->normalizePrefix(input); auto p = NormalizePrefix(input); - absl::string_view sp = p.first; + int sp_consumed = p.second; - // Removes heading spaces in sentence piece, - // if the previous sentence piece ends with whitespace. - while (is_prev_space && absl::ConsumePrefix(&sp, " ")) {} + case_encoder->push(p, /*last=*/input.size() == sp_consumed); - if (!sp.empty()) { - for (size_t n = 0; n < sp.size(); ++n) { - if (spec_->escape_whitespaces() && sp.data()[n] == ' ') { - bool append = case_encoder->encode(sp, n, p.second, consumed); - if(append) { + while(!case_encoder->empty()) { + absl::string_view sp = case_encoder->pop().first; + + // Removes heading spaces in sentence piece, + // if the previous sentence piece ends with whitespace. + while (is_prev_space && absl::ConsumePrefix(&sp, " ")) {} + + if (!sp.empty()) { + for (size_t n = 0; n < sp.size(); ++n) { + if (spec_->escape_whitespaces() && sp.data()[n] == ' ') { // replace ' ' with kSpaceSymbol. normalized->append(kSpaceSymbol.data(), kSpaceSymbol.size()); for (size_t m = 0; m < kSpaceSymbol.size(); ++m) { norm_to_orig->push_back(consumed); } - } - } else { - if(n == 0) { - if(input.data()[0] == 'U' && p.second > 1) { - int index = (input.data() + p.second - 1) - buffer.data(); - buffer[index] = 'U'; - std::cerr << std::string(sp.data(), sp.size()) << " " << p.second << " -> " << (buffer.data() + index) << std::endl; - input = absl::string_view(input.data() - 1, input.size() + 1); - shifted = true; - } else if(input.data()[0] == 'U' && p.second == 1) { - shifted = false; - continue; - } else if(input.data()[0] == 'L' && p.second == 1) { - shifted = false; - continue; - } else { - shifted = false; - } - } - - bool append = case_encoder->encode(sp, n, p.second, consumed); - if(append) { - normalized->append(sp.data() + n, 1); - norm_to_orig->push_back(consumed); + } else { + normalized->append(sp.data() + n, 1); + norm_to_orig->push_back(consumed); } } + + // Checks whether the last character of sp is whitespace. + is_prev_space = absl::EndsWith(sp, " "); } - // Checks whether the last character of sp is whitespace. - is_prev_space = absl::EndsWith(sp, " "); + if (!spec_->remove_extra_whitespaces()) { + is_prev_space = false; + } } - consumed += p.second - correction; - std::cerr << "consumed: " << consumed << " -> " << (buffer.data() + consumed) << std::endl; - input.remove_prefix(p.second); - if (!spec_->remove_extra_whitespaces()) { - is_prev_space = false; - } + consumed += sp_consumed; + input.remove_prefix(sp_consumed); } + //////////////////////////////////////////////////////////////////////////////////// + // Ignores tailing space. if (spec_->remove_extra_whitespaces()) { const absl::string_view space = |