diff options
author | Taku Kudo <taku@google.com> | 2018-06-07 16:44:11 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-07 18:06:42 +0300 |
commit | 54ccef78b800625a58cbdbac1245d77c9b744e84 (patch) | |
tree | 977467b98b2517e20bfa6df17c5f79df92f6106b /src/bpe_model.cc | |
parent | 7edac0b4ef81e94a1fdf041fab03771f943c9643 (diff) |
Support user defined symbols in Char/BPE
Diffstat (limited to 'src/bpe_model.cc')
-rw-r--r-- | src/bpe_model.cc | 29 |
1 files changed, 13 insertions, 16 deletions
```diff
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
index 6b3fe2e..cbcb51e 100644
--- a/src/bpe_model.cc
+++ b/src/bpe_model.cc
@@ -26,7 +26,7 @@ namespace bpe {
 
 Model::Model(const ModelProto &model_proto) {
   model_proto_ = &model_proto;
-  InitializePieces(false /* use_user_defined */);
+  InitializePieces(true /* use prefix matcher */);
 }
 
 Model::~Model() {}
@@ -53,8 +53,9 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
   };
 
   struct Symbol {
-    int prev;  // prev index of this symbol. -1 for BOS.
-    int next;  // next index of tihs symbol. -1 for EOS.
+    int prev;     // prev index of this symbol. -1 for BOS.
+    int next;     // next index of tihs symbol. -1 for EOS.
+    bool freeze;  // this symbol is never be merged.
     StringPiece piece;
   };
 
@@ -73,7 +74,9 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
   // Lookup new symbol pair at [left, right] and inserts it to agenda.
   auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
                                    int left, int right) {
-    if (left == -1 || right == -1) return;
+    if (left == -1 || right == -1 || symbols[left].freeze ||
+        symbols[right].freeze)
+      return;
     const StringPiece piece(
         symbols[left].piece.data(),
         symbols[left].piece.size() + symbols[right].piece.size());
@@ -96,20 +99,14 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
   };
 
   // Splits the input into character sequence
-  const char *begin = normalized.data();
-  const char *end = normalized.data() + normalized.size();
   int index = 0;
-  while (begin < end) {
-    int mblen = string_util::OneCharLen(begin);
-    if (mblen > end - begin) {
-      LOG(ERROR) << "Invalid character length.";
-      mblen = end - begin;
-    }
+  while (!normalized.empty()) {
     Symbol s;
-    s.piece = StringPiece(begin, mblen);
-    s.prev = begin == normalized.data() ? -1 : index - 1;
-    begin += mblen;
-    s.next = begin == end ? -1 : index + 1;
+    const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
+    s.piece = StringPiece(normalized.data(), mblen);
+    s.prev = index == 0 ? -1 : index - 1;
+    normalized.remove_prefix(mblen);
+    s.next = normalized.empty() ? -1 : index + 1;
     ++index;
     symbols.emplace_back(s);
   }
```