Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2018-06-07 16:44:11 +0300
committer Taku Kudo <taku@google.com> 2018-06-07 18:06:42 +0300
commit 54ccef78b800625a58cbdbac1245d77c9b744e84 (patch)
tree 977467b98b2517e20bfa6df17c5f79df92f6106b /src/bpe_model.cc
parent 7edac0b4ef81e94a1fdf041fab03771f943c9643 (diff)
Support user defined symbols in Char/BPE
Diffstat (limited to 'src/bpe_model.cc')
-rw-r--r-- src/bpe_model.cc 29
1 files changed, 13 insertions, 16 deletions
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
index 6b3fe2e..cbcb51e 100644
--- a/src/bpe_model.cc
+++ b/src/bpe_model.cc
@@ -26,7 +26,7 @@ namespace bpe {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
- InitializePieces(false /* use_user_defined */);
+ InitializePieces(true /* use prefix matcher */);
}
Model::~Model() {}
@@ -53,8 +53,9 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
};
struct Symbol {
- int prev; // prev index of this symbol. -1 for BOS.
- int next; // next index of tihs symbol. -1 for EOS.
+ int prev; // prev index of this symbol. -1 for BOS.
+ int next; // next index of tihs symbol. -1 for EOS.
+ bool freeze; // this symbol is never be merged.
StringPiece piece;
};
@@ -73,7 +74,9 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
// Lookup new symbol pair at [left, right] and inserts it to agenda.
auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
int left, int right) {
- if (left == -1 || right == -1) return;
+ if (left == -1 || right == -1 || symbols[left].freeze ||
+ symbols[right].freeze)
+ return;
const StringPiece piece(
symbols[left].piece.data(),
symbols[left].piece.size() + symbols[right].piece.size());
@@ -96,20 +99,14 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
};
// Splits the input into character sequence
- const char *begin = normalized.data();
- const char *end = normalized.data() + normalized.size();
int index = 0;
- while (begin < end) {
- int mblen = string_util::OneCharLen(begin);
- if (mblen > end - begin) {
- LOG(ERROR) << "Invalid character length.";
- mblen = end - begin;
- }
+ while (!normalized.empty()) {
Symbol s;
- s.piece = StringPiece(begin, mblen);
- s.prev = begin == normalized.data() ? -1 : index - 1;
- begin += mblen;
- s.next = begin == end ? -1 : index + 1;
+ const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
+ s.piece = StringPiece(normalized.data(), mblen);
+ s.prev = index == 0 ? -1 : index - 1;
+ normalized.remove_prefix(mblen);
+ s.next = normalized.empty() ? -1 : index + 1;
++index;
symbols.emplace_back(s);
}