Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2018-06-06 13:18:53 +0300
committer: Taku Kudo <taku@google.com> 2018-06-06 13:18:53 +0300
commit: aeec94b63ae2e96be762fdfe6f70d934ebd2beb4 (patch)
tree: 05190bdf102a7f760dbcff3342f0290ab220606b /src/bpe_model.cc
parent: 93d84a6f3c52fd7407fd8e5e51fb6d1344bf7675 (diff)
Support vocab restriction feature in BPE model.
Diffstat (limited to 'src/bpe_model.cc')
-rw-r--r-- src/bpe_model.cc 42
1 files changed, 38 insertions, 4 deletions
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
index 2aad9f9..6b3fe2e 100644
--- a/src/bpe_model.cc
+++ b/src/bpe_model.cc
@@ -14,6 +14,7 @@
#include "bpe_model.h"
+#include <functional>
#include <memory>
#include <queue>
#include <utility>
@@ -63,8 +64,15 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
std::vector<Symbol> symbols;
symbols.reserve(normalized.size());
+ // Reverse merge rules.
+ // key: merged symbol, value: pair of original symbols.
+ std::unordered_map<StringPiece, std::pair<StringPiece, StringPiece>,
+ StringPieceHash>
+ rev_merge;
+
// Lookup new symbol pair at [left, right] and inserts it to agenda.
- auto MaybeAddNewSymbolPair = [this, &symbols, &agenda](int left, int right) {
+ auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
+ int left, int right) {
if (left == -1 || right == -1) return;
const StringPiece piece(
symbols[left].piece.data(),
@@ -79,6 +87,12 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
h->score = GetScore(it->second);
h->size = piece.size();
agenda.push(h);
+
+ // Makes `rev_merge` for resegmentation.
+ if (IsUnused(it->second)) {
+ rev_merge[piece] =
+ std::make_pair(symbols[left].piece, symbols[right].piece);
+ }
};
// Splits the input into character sequence
@@ -114,14 +128,14 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
std::unique_ptr<SymbolPair> top(agenda.top());
agenda.pop();
- // |top| is no longer available.
+ // `top` is no longer available.
if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
top->size) {
continue;
}
- // Replaces symbols with |top| rule.
+ // Replaces symbols with `top` rule.
symbols[top->left].piece = StringPiece(
symbols[top->left].piece.data(),
symbols[top->left].piece.size() + symbols[top->right].piece.size());
@@ -138,11 +152,31 @@ std::vector<std::pair<StringPiece, int>> Model::Encode(
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
}
+ std::function<void(StringPiece, EncodeResult *)> resegment;
+ resegment = [this, &resegment, &rev_merge](StringPiece w,
+ EncodeResult *output) -> void {
+ const int id = PieceToId(w);
+ if (id == -1 || !IsUnused(id)) {
+ output->emplace_back(w, id);
+ return;
+ }
+ const auto p = rev_merge.find(w);
+ if (p == rev_merge.end()) {
+ // This block will never be called, as `rev_merge` stores all the
+ // resegmentation info for unused id.
+ output->emplace_back(w, id);
+ return;
+ }
+ // Recursively resegment left and right symbols.
+ resegment(p->second.first, output);
+ resegment(p->second.second, output);
+ };
+
EncodeResult output;
for (int index = 0; index != -1; index = symbols[index].next) {
CHECK_GE(index, 0);
CHECK_LT(index, static_cast<int>(symbols.size()));
- output.emplace_back(symbols[index].piece, PieceToId(symbols[index].piece));
+ resegment(symbols[index].piece, &output);
}
return output;