Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2017-03-07 13:43:50 +0300
committerTaku Kudo <taku@google.com>2017-03-07 13:43:50 +0300
commit2928ce5307224ea4c012fc6cbd7a098c486590b6 (patch)
tree38b679886855a7a6b80fdc61f2f62c952cf3bfb7 /src/bpe_model.cc
Initialize repository
Diffstat (limited to 'src/bpe_model.cc')
-rw-r--r--src/bpe_model.cc159
1 files changed, 159 insertions, 0 deletions
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
new file mode 100644
index 0000000..d226b0d
--- /dev/null
+++ b/src/bpe_model.cc
@@ -0,0 +1,159 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.!
+
+#include "bpe_model.h"
+
+#include <queue>
+#include "util.h"
+
+namespace sentencepiece {
+namespace bpe {
+
+Model::Model(const ModelProto &model_proto) {
+ model_proto_ = &model_proto;
+ CheckControlSymbols();
+
+ for (int i = 0; i < model_proto_->pieces_size(); ++i) {
+ const auto &sp = model_proto_->pieces(i);
+ CHECK(!sp.piece().empty());
+ if (sp.type() == ModelProto::SentencePiece::NORMAL) {
+ CHECK(sp.has_score());
+ port::InsertOrDie(&pieces_, sp.piece(), i);
+ } else if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
+ // TODO(taku): implement USER_DEFINED symbol.
+ LOG(FATAL) << "User defined symbol is not supported in BPE";
+ } else {
+ port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
+ }
+ }
+}
+
+Model::~Model() {}
+
+std::vector<std::pair<StringPiece, int>> Model::Encode(
+ StringPiece normalized) const {
+ if (normalized.empty()) {
+ return {};
+ }
+
+ struct SymbolPair {
+ int left; // left index of this pair
+ int right; // right index of this pair
+ float score; // score of this pair. large is better.
+ size_t size; // length of this piece
+ };
+
+ class SymbolPairComparator {
+ public:
+ const bool operator()(SymbolPair *h1, SymbolPair *h2) {
+ return (h1->score < h2->score ||
+ (h1->score == h2->score && h1->left > h2->left));
+ }
+ };
+
+ struct Symbol {
+ int prev; // prev index of this symbol. -1 for BOS.
+ int next; // next index of tihs symbol. -1 for EOS.
+ StringPiece piece;
+ };
+
+ using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
+ SymbolPairComparator>;
+ Agenda agenda;
+ std::vector<Symbol> symbols;
+ symbols.reserve(normalized.size());
+
+ // Lookup new symbol pair at [left, right] and inserts it to agenda.
+ auto MaybeAddNewSymbolPair = [this, &symbols, &agenda](int left, int right) {
+ if (left == -1 || right == -1) return;
+ const StringPiece piece(
+ symbols[left].piece.data(),
+ symbols[left].piece.size() + symbols[right].piece.size());
+ const auto it = pieces_.find(piece);
+ if (it == pieces_.end()) {
+ return;
+ }
+ auto *h = new SymbolPair;
+ h->left = left;
+ h->right = right;
+ h->score = GetScore(it->second);
+ h->size = piece.size();
+ agenda.push(h);
+ };
+
+ // Splits the input into character sequence
+ const char *begin = normalized.data();
+ const char *end = normalized.data() + normalized.size();
+ int index = 0;
+ while (begin < end) {
+ int mblen = string_util::OneCharLen(begin);
+ if (mblen > end - begin) {
+ LOG(ERROR) << "Invalid character length.";
+ mblen = end - begin;
+ }
+ Symbol s;
+ s.piece = StringPiece(begin, mblen);
+ s.prev = begin == normalized.data() ? -1 : index - 1;
+ begin += mblen;
+ s.next = begin == end ? -1 : index + 1;
+ ++index;
+ symbols.emplace_back(s);
+ }
+ CHECK(!symbols.empty());
+
+ // Lookup all bigrams.
+ for (size_t i = 1; i < symbols.size(); ++i) {
+ MaybeAddNewSymbolPair(i - 1, i);
+ }
+
+ // Main loop.
+ while (!agenda.empty()) {
+ std::unique_ptr<SymbolPair> top(agenda.top());
+ agenda.pop();
+
+ // |top| is no longer available.
+ if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
+ symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
+ top->size) {
+ continue;
+ }
+
+ // Replaces symbols with |top| rule.
+ symbols[top->left].piece = StringPiece(
+ symbols[top->left].piece.data(),
+ symbols[top->left].piece.size() + symbols[top->right].piece.size());
+
+ // Updates prev/next pointers.
+ symbols[top->left].next = symbols[top->right].next;
+ if (symbols[top->right].next >= 0) {
+ symbols[symbols[top->right].next].prev = top->left;
+ }
+ symbols[top->right].piece = StringPiece("");
+
+ // Adds new symbol pairs which are newly added after symbol replacement.
+ MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
+ MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
+ }
+
+ std::vector<std::pair<StringPiece, int>> output;
+ for (int index = 0; index != -1; index = symbols[index].next) {
+ CHECK_GE(index, 0);
+ CHECK_LT(index, static_cast<int>(symbols.size()));
+ output.emplace_back(symbols[index].piece, PieceToId(symbols[index].piece));
+ }
+
+ return output;
+}
+} // namespace bpe
+} // namespace sentencepiece