diff options
author | Taku Kudo <taku@google.com> | 2018-06-06 13:18:53 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-06 13:18:53 +0300 |
commit | aeec94b63ae2e96be762fdfe6f70d934ebd2beb4 (patch) | |
tree | 05190bdf102a7f760dbcff3342f0290ab220606b /src/bpe_model.cc | |
parent | 93d84a6f3c52fd7407fd8e5e51fb6d1344bf7675 (diff) |
Support vocab restriction feature in BPE model.
Diffstat (limited to 'src/bpe_model.cc')
-rw-r--r-- | src/bpe_model.cc | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/src/bpe_model.cc b/src/bpe_model.cc index 2aad9f9..6b3fe2e 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -14,6 +14,7 @@ #include "bpe_model.h" +#include <functional> #include <memory> #include <queue> #include <utility> @@ -63,8 +64,15 @@ std::vector<std::pair<StringPiece, int>> Model::Encode( std::vector<Symbol> symbols; symbols.reserve(normalized.size()); + // Reverse merge rules. + // key: merged symbol, value: pair of original symbols. + std::unordered_map<StringPiece, std::pair<StringPiece, StringPiece>, + StringPieceHash> + rev_merge; + // Lookup new symbol pair at [left, right] and inserts it to agenda. - auto MaybeAddNewSymbolPair = [this, &symbols, &agenda](int left, int right) { + auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge]( + int left, int right) { if (left == -1 || right == -1) return; const StringPiece piece( symbols[left].piece.data(), @@ -79,6 +87,12 @@ std::vector<std::pair<StringPiece, int>> Model::Encode( h->score = GetScore(it->second); h->size = piece.size(); agenda.push(h); + + // Makes `rev_merge` for resegmentation. + if (IsUnused(it->second)) { + rev_merge[piece] = + std::make_pair(symbols[left].piece, symbols[right].piece); + } }; // Splits the input into character sequence @@ -114,14 +128,14 @@ std::vector<std::pair<StringPiece, int>> Model::Encode( std::unique_ptr<SymbolPair> top(agenda.top()); agenda.pop(); - // |top| is no longer available. + // `top` is no longer available. if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() || symbols[top->left].piece.size() + symbols[top->right].piece.size() != top->size) { continue; } - // Replaces symbols with |top| rule. + // Replaces symbols with `top` rule. symbols[top->left].piece = StringPiece( symbols[top->left].piece.data(), symbols[top->left].piece.size() + symbols[top->right].piece.size()); @@ -138,11 +152,31 @@ std::vector<std::pair<StringPiece, int>> Model::Encode( MaybeAddNewSymbolPair(top->left, symbols[top->left].next); } + std::function<void(StringPiece, EncodeResult *)> resegment; + resegment = [this, &resegment, &rev_merge](StringPiece w, + EncodeResult *output) -> void { + const int id = PieceToId(w); + if (id == -1 || !IsUnused(id)) { + output->emplace_back(w, id); + return; + } + const auto p = rev_merge.find(w); + if (p == rev_merge.end()) { + // This block will never be called, as `rev_merge` stores all the + // resegmentation info for unused id. + output->emplace_back(w, id); + return; + } + // Recursively resegment left and right symbols. + resegment(p->second.first, output); + resegment(p->second.second, output); + }; + EncodeResult output; for (int index = 0; index != -1; index = symbols[index].next) { CHECK_GE(index, 0); CHECK_LT(index, static_cast<int>(symbols.size())); - output.emplace_back(symbols[index].piece, PieceToId(symbols[index].piece)); + resegment(symbols[index].piece, &output); } return output; |