Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-05-05 09:39:15 +0300
committerMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-05-05 09:39:15 +0300
commit5c5483a82d8edcaf2d7f9c76e9977f07c4880e4a (patch)
treecdfa1c75669bc965fe681040941ed28f2704974b
parent0b1208082fc0694bd51679ad32f1718793ca0e27 (diff)
rewrite case normalizer
-rw-r--r--src/case_encoder.h186
-rw-r--r--src/normalizer.cc80
2 files changed, 130 insertions, 136 deletions
diff --git a/src/case_encoder.h b/src/case_encoder.h
index 2e0a04f..afc0dca 100644
--- a/src/case_encoder.h
+++ b/src/case_encoder.h
@@ -19,7 +19,7 @@
#include <set>
#include <string>
#include <utility>
-#include <vector>
+#include <deque>
#include "common.h"
#include "third_party/absl/strings/string_view.h"
@@ -31,125 +31,141 @@ namespace normalizer {
class CaseEncoder {
public:
virtual ~CaseEncoder() {}
- virtual bool encode(const absl::string_view& sp, int n, int src, int consumed) = 0;
+ virtual void push(const std::pair<absl::string_view, int>& p, bool last) = 0;
+ virtual bool empty() = 0;
+ virtual std::pair<absl::string_view, int> pop() = 0;
- static std::unique_ptr<CaseEncoder> Create(bool, bool, absl::string_view* /*input*/, std::string* /*normalized*/, std::vector<size_t>* /*norm_to_orig*/);
+ static std::unique_ptr<CaseEncoder> Create(bool, bool);
};
class IdentityCaseEncoder : public CaseEncoder {
+private:
+ std::pair<absl::string_view, int> p_;
+ bool empty_{true};
+
public:
IdentityCaseEncoder() {}
- bool encode(const absl::string_view& sp, int n, int src, int consumed) {
- return true;
+ void push(const std::pair<absl::string_view, int>& p, bool /*last*/) {
+ p_ = p;
+ empty_ = false;
+ }
+
+ bool empty() {
+ return empty_;
+ }
+
+ std::pair<absl::string_view, int> pop() {
+ empty_ = true;
+ return p_;
}
};
class UpperCaseEncoder : public CaseEncoder {
-private:
- char* last_u_{nullptr};
- size_t last_u_dist_{0};
- std::string* normalized_;
- std::vector<size_t> *norm_to_orig_;
-
-public:
- UpperCaseEncoder(std::string* normalized, std::vector<size_t> *norm_to_orig)
- : normalized_(normalized), norm_to_orig_(norm_to_orig) {}
-
- bool encode(const absl::string_view& sp, int n, int src, int consumed) {
- if(n != 0)
- return true;
-
- char curChar = sp.data()[0];
- if(curChar == ' ') {
- if(last_u_) {
- if(last_u_dist_ == 1)
- *last_u_ = 'T';
- last_u_ = nullptr;
- last_u_dist_ = 0;
+ std::vector<std::string> buffers_;
+ std::deque<std::pair<absl::string_view, int>> pieces_;
+ bool flush_{false};
+ size_t countU_{0};
+
+ void fixUs() {
+ if(countU_ == 1) {
+ auto sp = pieces_.front().first;
+ buffers_.emplace_back(sp.data(), sp.size());
+ buffers_.back()[0] = 'T';
+ pieces_.front().first = absl::string_view(buffers_.back());
+ } else if(countU_ > 1) {
+ for(int i = 1; i < countU_; ++i) {
+ auto sp = pieces_[i].first;
+ pieces_[i].first = absl::string_view(sp.data() + 1, sp.size() - 1);
}
- return true;
}
+ }
- if(!last_u_ && curChar == 'U') {
- last_u_ = &(*normalized_)[0] + normalized_->size();
- last_u_dist_++;
- } else if(last_u_ && curChar == 'U') { // uppercase sequence, skip over U
- last_u_dist_++;
- return false;
- } else if (last_u_ && curChar != 'U' && last_u_dist_ == 1) { // single uppercase letter
- *last_u_ = 'T';
- last_u_ = nullptr;
- last_u_dist_ = 0;
- } else if (last_u_ && (curChar != 'U' && curChar != 'P') && last_u_dist_ > 1) {
- // we had a longer uppercase sequence, hence need to insert 'L'
- normalized_->append(1, 'L');
- norm_to_orig_->push_back(consumed);
-
- last_u_ = nullptr;
- last_u_dist_ = 0;
+public:
+ UpperCaseEncoder() {}
+
+ void push(const std::pair<absl::string_view, int>& p, bool last) {
+ auto sp = p.first;
+ if(sp.data()[0] == 'U') {
+ pieces_.push_back(p);
+ countU_++;
+ flush_ = false;
+ } else if(sp.data()[0] == 'P') {
+ fixUs();
+ pieces_.push_back({absl::string_view(sp.data() + 1, sp.size() - 1), p.second});
+ countU_ = 0;
+ flush_ = true;
+ } else if(sp.data()[0] == ' ') {
+ fixUs();
+ pieces_.push_back(p);
+ countU_ = 0;
+ flush_ = true;
} else {
- last_u_ = nullptr;
- last_u_dist_ = 0;
+ fixUs();
+ if(countU_ > 1) {
+ buffers_.emplace_back("L");
+ buffers_.back().append(p.first.data(), p.first.size());
+ pieces_.push_back({buffers_.back(), p.second});
+ } else {
+ pieces_.push_back(p);
+ }
+ countU_ = 0;
+ flush_ = true;
}
- if(curChar == 'P') {
- last_u_ = nullptr;
- last_u_dist_ = 0;
- return false;
- }
+ if(last)
+ flush_ = true; // flush it all out
+ }
+
+ bool empty() {
+ return pieces_.empty() || !flush_;
+ }
- return true;
+ std::pair<absl::string_view, int> pop() {
+ auto p = pieces_.front();
+ pieces_.pop_front();
+ return p;
}
};
class UpperCaseDecoder : public CaseEncoder {
-private:
- char* last_u_{nullptr};
- size_t last_u_dist_{0};
- std::string* normalized_;
- std::vector<size_t> *norm_to_orig_;
-
- std::string buffer_;
- absl::string_view* input_;
+ std::vector<std::string> buffers_;
+ std::deque<std::pair<absl::string_view, int>> pieces_;
+ bool flush_{false};
public:
- UpperCaseDecoder(absl::string_view* input, std::string* normalized, std::vector<size_t> *norm_to_orig)
- : normalized_(normalized), norm_to_orig_(norm_to_orig), buffer_(input->data(), input->size()), input_(input) {
- *input = absl::string_view(buffer_);
-
- // if(buffer_[0] == 'T')
- // buffer_[0] = 'U';
- }
+ UpperCaseDecoder() {}
- bool encode(const absl::string_view& sp, int n, int src, int consumed) {
- if(n != 0)
- return true;
+ void push(const std::pair<absl::string_view, int>& p, bool last) {
+ auto sp = p.first;
- std::cerr << "B: " << n << " " << src << " " << consumed << " " << std::string(input_->data(), input_->size()) << std::endl;
+ std::cerr << p.first << std::endl;
+ pieces_.push_back(p);
+ flush_ = true;
- // if(consumed + src < buffer_.size() && buffer_[consumed + src] == 'T') {
- // buffer_[consumed + src] = 'U';
- // }
-
- if(consumed + src < buffer_.size() && input_->data()[0] == 'U') {
- buffer_[consumed + src - 1] = 'U';
- *input_ = absl::string_view(input_->data() - 1, input_->size() + 1);
- }
+ if(last)
+ flush_ = true; // flush it all out
+ }
- std::cerr << "A: " << n << " " << src << " " << consumed << " " << std::string(input_->data(), input_->size()) << std::endl;
- return true;
+ bool empty() {
+ return pieces_.empty() || !flush_;
+ }
+
+ std::pair<absl::string_view, int> pop() {
+ auto p = pieces_.front();
+ pieces_.pop_front();
+ return p;
}
};
-std::unique_ptr<CaseEncoder> CaseEncoder::Create(bool encodeCase, bool decodeCase, absl::string_view* input, std::string* normalized, std::vector<size_t>* norm_to_orig) {
+std::unique_ptr<CaseEncoder> CaseEncoder::Create(bool encodeCase, bool decodeCase) {
+ // LOG(INFO) << encodeCase << " " << decodeCase;
if(encodeCase && decodeCase) {
LOG(ERROR) << "Cannot set both encodeCase=true and decodeCase=true";
return nullptr;
} else if(encodeCase) {
- return std::unique_ptr<CaseEncoder>(new UpperCaseEncoder(normalized, norm_to_orig));
+ return std::unique_ptr<CaseEncoder>(new UpperCaseEncoder());
} else if(decodeCase) {
- return std::unique_ptr<CaseEncoder>(new IdentityCaseEncoder());
- // return std::unique_ptr<CaseEncoder>(new UpperCaseDecoder(input, normalized, norm_to_orig));
+ return std::unique_ptr<CaseEncoder>(new UpperCaseDecoder());
} else {
return std::unique_ptr<CaseEncoder>(new IdentityCaseEncoder());
}
diff --git a/src/normalizer.cc b/src/normalizer.cc
index ced56c4..c11db4a 100644
--- a/src/normalizer.cc
+++ b/src/normalizer.cc
@@ -77,11 +77,6 @@ util::Status Normalizer::Normalize(absl::string_view input,
std::string *normalized,
std::vector<size_t> *norm_to_orig) const {
- std::string buffer(input.data(), input.size());
- input = absl::string_view(buffer);
-
- std::cerr << "In: " << buffer << std::endl;
-
norm_to_orig->clear();
normalized->clear();
@@ -139,71 +134,54 @@ util::Status Normalizer::Normalize(absl::string_view input,
if (!treat_whitespace_as_suffix_ && spec_->add_dummy_prefix())
add_ws();
- std::unique_ptr<CaseEncoder> case_encoder = CaseEncoder::Create(spec_->encode_case(), spec_->decode_case(), &input, normalized, norm_to_orig);
- int correction = 0;
- bool shifted = false;
+ std::unique_ptr<CaseEncoder> case_encoder = CaseEncoder::Create(spec_->encode_case(), spec_->decode_case());
+
+ ////////////////////////////////////////////////////////////////////////////////////
bool is_prev_space = spec_->remove_extra_whitespaces();
while (!input.empty()) {
- correction = shifted ? 1 : 0;
-
+ // auto p = case_encoder->normalizePrefix(input);
auto p = NormalizePrefix(input);
- absl::string_view sp = p.first;
+ int sp_consumed = p.second;
- // Removes heading spaces in sentence piece,
- // if the previous sentence piece ends with whitespace.
- while (is_prev_space && absl::ConsumePrefix(&sp, " ")) {}
+ case_encoder->push(p, /*last=*/input.size() == sp_consumed);
- if (!sp.empty()) {
- for (size_t n = 0; n < sp.size(); ++n) {
- if (spec_->escape_whitespaces() && sp.data()[n] == ' ') {
- bool append = case_encoder->encode(sp, n, p.second, consumed);
- if(append) {
+ while(!case_encoder->empty()) {
+ absl::string_view sp = case_encoder->pop().first;
+
+ // Removes heading spaces in sentence piece,
+ // if the previous sentence piece ends with whitespace.
+ while (is_prev_space && absl::ConsumePrefix(&sp, " ")) {}
+
+ if (!sp.empty()) {
+ for (size_t n = 0; n < sp.size(); ++n) {
+ if (spec_->escape_whitespaces() && sp.data()[n] == ' ') {
// replace ' ' with kSpaceSymbol.
normalized->append(kSpaceSymbol.data(), kSpaceSymbol.size());
for (size_t m = 0; m < kSpaceSymbol.size(); ++m) {
norm_to_orig->push_back(consumed);
}
- }
- } else {
- if(n == 0) {
- if(input.data()[0] == 'U' && p.second > 1) {
- int index = (input.data() + p.second - 1) - buffer.data();
- buffer[index] = 'U';
- std::cerr << std::string(sp.data(), sp.size()) << " " << p.second << " -> " << (buffer.data() + index) << std::endl;
- input = absl::string_view(input.data() - 1, input.size() + 1);
- shifted = true;
- } else if(input.data()[0] == 'U' && p.second == 1) {
- shifted = false;
- continue;
- } else if(input.data()[0] == 'L' && p.second == 1) {
- shifted = false;
- continue;
- } else {
- shifted = false;
- }
- }
-
- bool append = case_encoder->encode(sp, n, p.second, consumed);
- if(append) {
- normalized->append(sp.data() + n, 1);
- norm_to_orig->push_back(consumed);
+ } else {
+ normalized->append(sp.data() + n, 1);
+ norm_to_orig->push_back(consumed);
}
}
+
+ // Checks whether the last character of sp is whitespace.
+ is_prev_space = absl::EndsWith(sp, " ");
}
- // Checks whether the last character of sp is whitespace.
- is_prev_space = absl::EndsWith(sp, " ");
+ if (!spec_->remove_extra_whitespaces()) {
+ is_prev_space = false;
+ }
}
- consumed += p.second - correction;
- std::cerr << "consumed: " << consumed << " -> " << (buffer.data() + consumed) << std::endl;
- input.remove_prefix(p.second);
- if (!spec_->remove_extra_whitespaces()) {
- is_prev_space = false;
- }
+ consumed += sp_consumed;
+ input.remove_prefix(sp_consumed);
}
+ ////////////////////////////////////////////////////////////////////////////////////
+
// Ignores tailing space.
if (spec_->remove_extra_whitespaces()) {
const absl::string_view space =