Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2017-03-07 13:43:50 +0300
committerTaku Kudo <taku@google.com>2017-03-07 13:43:50 +0300
commit2928ce5307224ea4c012fc6cbd7a098c486590b6 (patch)
tree38b679886855a7a6b80fdc61f2f62c952cf3bfb7 /src/bpe_model_trainer.cc
Initialize repository
Diffstat (limited to 'src/bpe_model_trainer.cc')
-rw-r--r--src/bpe_model_trainer.cc323
1 files changed, 323 insertions, 0 deletions
diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc
new file mode 100644
index 0000000..f80dd75
--- /dev/null
+++ b/src/bpe_model_trainer.cc
@@ -0,0 +1,323 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bpe_model_trainer.h"
+
+#include <unordered_set>
+#include "util.h"
+
+namespace sentencepiece {
+namespace bpe {
+
+// Converts the code points held in |chars| to a UTF-8 string by delegating to
+// string_util::UnicodeTextToUTF8.
+std::string Trainer::Symbol::ToString() const {
+  std::string utf8 = string_util::UnicodeTextToUTF8(chars);
+  return utf8;
+}
+
+// Returns the single-character symbol for code point |c|.
+//
+// The symbol is created on first request, recorded in |allocated_| and
+// registered in |symbols_cache_| keyed by the code point itself; later
+// requests return the cached instance. Its frequency comes from
+// port::FindWithDefault(required_chars_, c, 1) and is checked to be positive
+// on every call, before the cache is consulted.
+Trainer::Symbol *Trainer::GetCharSymbol(char32 c) {
+  const uint64 freq = port::FindWithDefault(required_chars_, c, 1);
+  CHECK_GT(freq, 0);
+
+  const auto cached = symbols_cache_.find(c);
+  if (cached == symbols_cache_.end()) {
+    // Cache miss: build a new unigram symbol and register it.
+    Symbol *unigram = new Symbol;
+    allocated_.push_back(unigram);
+    unigram->is_unk = (c == kUNKChar);
+    unigram->fp = c;
+    unigram->chars.push_back(c);
+    unigram->freq = freq;
+    port::InsertOrDie(&symbols_cache_, unigram->fp, unigram);
+    return unigram;
+  }
+  return cached->second;
+}
+
+// Returns the bigram symbol formed by concatenating |left| and |right|.
+//
+// Returns nullptr when either side is null or flagged as unknown, or when the
+// concatenated text is rejected by IsValidSentencePiece. Otherwise the symbol
+// is looked up in |symbols_cache_| by the combined fingerprint and created
+// (and cached) if it is not there yet.
+Trainer::Symbol *Trainer::GetPairSymbol(const Symbol *left,
+                                        const Symbol *right) {
+  const bool mergeable = left != nullptr && right != nullptr &&
+                         !left->is_unk && !right->is_unk;
+  if (!mergeable) return nullptr;
+
+  const uint64 fp = port::FingerprintCat(left->fp, right->fp);
+  const auto cached = symbols_cache_.find(fp);
+  if (cached != symbols_cache_.end()) return cached->second;
+
+  CHECK(!left->chars.empty());
+  CHECK(!right->chars.empty());
+
+  // Concatenate the code points of both halves, left first.
+  string_util::UnicodeText merged;
+  for (const Symbol *part : {left, right}) {
+    for (const char32 c : part->chars) merged.push_back(c);
+  }
+
+  // An invalid piece must never become a symbol.
+  if (!IsValidSentencePiece(merged)) return nullptr;
+
+  Symbol *bigram = new Symbol;
+  allocated_.push_back(bigram);
+  bigram->fp = fp;
+  bigram->left = left;
+  bigram->right = right;
+  bigram->chars = merged;
+  port::InsertOrDie(&symbols_cache_, bigram->fp, bigram);
+  return bigram;
+}
+
+// Recomputes |symbol->freq| from |symbol->positions| when it is zero.
+//
+// A positive frequency is taken as up to date and left untouched. During the
+// recount, positions that can no longer be counted are erased from
+// |symbol->positions|; every surviving position adds the second member of its
+// sentences_ entry to the frequency.
+void Trainer::ComputeFreq(Symbol *symbol) const {
+  if (symbol->freq > 0) return;  // freq == 0 means a recount is required.
+  CHECK_EQ(0, symbol->freq);
+
+  // Sentinel meaning "no position was counted just before this one".
+  const Position kNoPrevious = {-1, 0};
+  Position last_counted = kNoPrevious;
+
+  auto cursor = symbol->positions.begin();
+  while (cursor != symbol->positions.end()) {
+    const Position pos = DecodePos(*cursor);
+
+    // A position is countable only if
+    //  (1) it does not overlap the bigram counted just before it: "AAA"
+    //      contains [AA][AA], which share the middle "A", so the second one
+    //      starts where the first one ended (pos.left == last.right); and
+    //  (2) symbols_[sid][left] and symbols_[sid][right] still hold exactly
+    //      symbol->left and symbol->right.
+    const bool countable =
+        !(pos.sid == last_counted.sid && pos.left == last_counted.right) &&
+        symbol->left == symbols_[pos.sid][pos.left] &&
+        symbol->right == symbols_[pos.sid][pos.right];
+
+    if (countable) {
+      symbol->freq += sentences_[pos.sid].second;
+      last_counted = pos;
+      ++cursor;
+    } else {
+      cursor = symbol->positions.erase(cursor);
+      // Forget the previous position so that in "AAAA" the last "AA" can
+      // still be counted.
+      last_counted = kNoPrevious;
+    }
+  }
+}
+
+// Returns the smallest index greater than |index| in sentence |sid| whose
+// slot in symbols_ is non-null, or -1 when there is none.
+int Trainer::GetNextIndex(int sid, int index) const {
+  size_t cursor = index + 1;
+  while (cursor < symbols_[sid].size()) {
+    if (symbols_[sid][cursor] != nullptr) return static_cast<int>(cursor);
+    ++cursor;
+  }
+  return -1;
+}
+
+// Returns the largest index smaller than |index| in sentence |sid| whose
+// slot in symbols_ is non-null, or -1 when there is none.
+int Trainer::GetPrevIndex(int sid, int index) const {
+  int cursor = index - 1;
+  while (cursor >= 0 && symbols_[sid][cursor] == nullptr) {
+    --cursor;
+  }
+  return cursor >= 0 ? cursor : -1;
+}
+
+// Registers the bigram at (|left|, |right|) of sentence |sid|.
+//
+// Does nothing when either index is -1 or when GetPairSymbol yields no symbol
+// for the pair. Otherwise the symbol is added to |active_symbols_| and the
+// encoded position is appended to its position set.
+void Trainer::AddNewPair(int sid, int left, int right) {
+  const bool has_both_sides = left != -1 && right != -1;
+  if (!has_both_sides) return;
+
+  Symbol *bigram = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
+  if (bigram == nullptr) return;
+
+  active_symbols_.insert(bigram);
+  bigram->positions.insert(EncodePos(sid, left, right));
+}
+
+// Zeroes the frequency of the bigram at (|left|, |right|) of sentence |sid|
+// so that ComputeFreq recounts it later.
+//
+// Does nothing when either index is -1, when the pair has no symbol, or when
+// the pair's symbol is |best| itself.
+void Trainer::ResetFreq(int sid, int left, int right, const Symbol *best) {
+  if (left == -1 || right == -1) return;
+
+  Symbol *bigram = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
+  if (bigram == nullptr || bigram == best) return;
+
+  bigram->freq = 0;
+}
+
+// Rebuilds |active_symbols_| from the most frequent bigrams in the cache.
+//
+// Every bigram in |symbols_cache_| has its frequency brought up to date via
+// ComputeFreq. The top max(kMinActiveSymbolsSize, 5% of the cache size)
+// bigrams by frequency (capped at the number of bigrams available) then
+// replace the previous contents of |active_symbols_|.
+//
+// When the cache holds no bigram at all, |active_symbols_| is simply emptied.
+void Trainer::UpdateActiveSymbols() {
+  std::vector<Symbol *> symbols;
+  for (auto &it : symbols_cache_) {
+    Symbol *symbol = it.second;
+    if (symbol->IsBigram()) {
+      ComputeFreq(symbol);
+      symbols.push_back(symbol);
+    }
+  }
+
+  // Without any bigram there is nothing to rank. Return early: the logging
+  // below reads symbols[0] and symbols[size - 1], which would be out of
+  // bounds on an empty vector.
+  if (symbols.empty()) {
+    LOG(INFO) << "Updating active symbols. No bigram is available.";
+    active_symbols_.clear();
+    return;
+  }
+
+  // At least kMinActiveSymbolsSize symbols must be in |active_symbols_|.
+  constexpr int kMinActiveSymbolsSize = 1000;
+
+  // Keeps top 5% frequent symbols.
+  constexpr float kTopFrequentRatio = 0.05;
+  const int size =
+      std::min<int>(std::max<int>(kMinActiveSymbolsSize,
+                                  symbols_cache_.size() * kTopFrequentRatio),
+                    symbols.size());
+
+  // Only the first |size| entries need to be ordered, most frequent first.
+  std::partial_sort(symbols.begin(), symbols.begin() + size, symbols.end(),
+                    [](Symbol *s1, Symbol *s2) { return s1->freq > s2->freq; });
+  LOG(INFO) << "Updating active symbols. max_freq=" << symbols[0]->freq
+            << " min_freq=" << symbols[size - 1]->freq;
+
+  active_symbols_.clear();
+  active_symbols_.insert(symbols.begin(), symbols.begin() + size);
+}
+
+// Runs BPE training from start to finish.
+//
+// Steps:
+//  1. Validates |trainer_spec_| / |normalizer_spec_| with CHECKs.
+//  2. Clears the working state, loads the sentences and, if requested by the
+//     spec, splits them by whitespace.
+//  3. Seeds symbols_ with one single-character symbol per code point and
+//     registers every adjacent pair as a candidate bigram.
+//  4. Repeatedly picks the most frequent active bigram, appends it to
+//     |final_pieces_| and merges it in place at each of its positions, until
+//     the merge budget is used up or no candidate is left.
+//  5. Appends the required characters, calls Save() and frees every symbol
+//     recorded in |allocated_|.
+//
+// Each piece is scored with minus its insertion index, so earlier pieces get
+// higher scores.
+void Trainer::Train() {
+#define CHECK_RANGE(variable, minval, maxval) \
+  CHECK(variable >= minval && variable <= maxval)
+
+  CHECK_GT(trainer_spec_.input().size(), 0);
+  CHECK(!trainer_spec_.model_prefix().empty());
+  CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
+  CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
+  CHECK_RANGE(trainer_spec_.max_sentencepiece_length(), 1, 64);
+  CHECK_GT(trainer_spec_.vocab_size(), 0);
+#undef CHECK_RANGE
+
+  LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
+
+  CHECK(normalizer_spec_.escape_whitespaces());
+  CHECK_EQ(TrainerSpec::BPE, trainer_spec_.model_type());
+
+  symbols_.clear();
+  allocated_.clear();
+  symbols_cache_.clear();
+  active_symbols_.clear();
+
+  // Load all sentences
+  LoadSentences();
+
+  if (trainer_spec_.split_by_whitespace()) {
+    SplitSentencesByWhitespace();
+  }
+
+  // Initializes symbols_. symbols_[sid][i] stores an unary symbol.
+  symbols_.resize(sentences_.size());
+  for (size_t i = 0; i < sentences_.size(); ++i) {
+    for (const char32 c : string_util::UTF8ToUnicodeText(sentences_[i].first)) {
+      symbols_[i].push_back(GetCharSymbol(c));
+    }
+  }
+
+  // Makes all bigram symbols.
+  for (size_t sid = 0; sid < symbols_.size(); ++sid) {
+    for (size_t i = 1; i < symbols_[sid].size(); ++i) {
+      AddNewPair(sid, i - 1, i);
+    }
+  }
+
+  const int meta_symbols_size = trainer_spec_.control_symbols().size() +
+                                trainer_spec_.user_defined_symbols().size() +
+                                3;  // <s>, </s>, <unk>
+  // Number of merged pieces to learn. The subtraction is done in signed
+  // arithmetic so that a too-small vocab_size yields a negative value that
+  // the CHECK below reports, instead of relying on unsigned wrap-around.
+  const int vocab_size = trainer_spec_.vocab_size() - meta_symbols_size -
+                         static_cast<int>(required_chars_.size());
+  CHECK_GE(vocab_size, 0);
+
+  // We may see duplicated pieces that are extracted with different path.
+  // In real segmentation phase, we can consider them as one symbol.
+  // e.g., "aaa" => "aa" + "a" or "a" + "aa".
+  std::unordered_set<std::string> dup;
+
+  // Main loop.
+  CHECK(final_pieces_.empty());
+  while (final_pieces_.size() < static_cast<size_t>(vocab_size)) {
+    // The active set is rebuilt whenever the piece count is a multiple of
+    // this interval (including at the very start, when it is zero).
+    constexpr int kUpdateActiveSymbolsInterval = 100;
+    if (final_pieces_.size() % kUpdateActiveSymbolsInterval == 0) {
+      UpdateActiveSymbols();
+    }
+
+    // Scanning active symbols, finds the best_symbol with highest freq.
+    Symbol *best_symbol = nullptr;
+    for (auto &it : active_symbols_) {
+      Symbol *symbol = it;
+      ComputeFreq(symbol);
+      // If the frequency is the same, take shorter symbol.
+      // if the length is the same, use lexicographical comparison
+      if (best_symbol == nullptr ||
+          (symbol->freq > best_symbol->freq ||
+           (symbol->freq == best_symbol->freq &&
+            (symbol->chars.size() < best_symbol->chars.size() ||
+             (symbol->chars.size() == best_symbol->chars.size() &&
+              symbol->ToString() < best_symbol->ToString()))))) {
+        best_symbol = symbol;
+      }
+    }
+
+    if (best_symbol == nullptr) {
+      LOG(WARNING) << "No valid symbol found";
+      break;
+    }
+
+    if (!dup.insert(best_symbol->ToString()).second) {
+      // The same surface string was already emitted through another merge
+      // path. Removes best_symbol so it is not selected again.
+      symbols_cache_.erase(best_symbol->fp);
+      active_symbols_.erase(best_symbol);
+      continue;
+    }
+
+    // Stores the best_symbol in the final output.
+    // Convert to float BEFORE negating: size() is unsigned, so negating it
+    // directly would wrap around to a huge positive value.
+    const float score = -static_cast<float>(final_pieces_.size());
+    final_pieces_.emplace_back(best_symbol->ToString(), score);
+
+    if (final_pieces_.size() % 20 == 0) {
+      LOG(INFO) << "Added: freq=" << best_symbol->freq
+                << " size=" << final_pieces_.size()
+                << " all=" << symbols_cache_.size()
+                << " active=" << active_symbols_.size()
+                << " piece=" << best_symbol->ToString();
+    }
+
+    // Add new bigrams which are created after symbol replacement.
+    // We do not need to scan all characters, but scan the neighbors in
+    // best_symbol.
+    for (const uint64 &encoded_pos : best_symbol->positions) {
+      const Position pos = DecodePos(encoded_pos);
+
+      if (symbols_[pos.sid][pos.left] == nullptr) {
+        // left index might be NULL (set in the previous iteration)
+        // when left_symbol == right_symbol.
+        continue;
+      }
+      CHECK_NOTNULL(symbols_[pos.sid][pos.right]);
+
+      // We have three bigrams [prev, left], [left, right], [right, next],
+      // which are affected with this symbol replacement.
+      const int next = GetNextIndex(pos.sid, pos.right);
+      const int prev = GetPrevIndex(pos.sid, pos.left);
+
+      // Resets the frequencies of bigrams [prev, left] and [right, next].
+      ResetFreq(pos.sid, prev, pos.left, best_symbol);
+      ResetFreq(pos.sid, pos.right, next, best_symbol);
+
+      // Merges two symbols: the merged symbol lives in the left slot and the
+      // right slot is emptied.
+      symbols_[pos.sid][pos.left] = best_symbol;
+      symbols_[pos.sid][pos.right] = nullptr;
+
+      // Makes new symbol bigrams [prev, left] and [left, next].
+      AddNewPair(pos.sid, prev, pos.left);
+      AddNewPair(pos.sid, pos.left, next);
+    }
+
+    // Removes best_symbol so it is not selected again.
+    symbols_cache_.erase(best_symbol->fp);
+    active_symbols_.erase(best_symbol);
+  }  // end of main loop
+
+  // Adds required_chars_
+  for (const auto &w : Sorted(required_chars_)) {
+    const Symbol *symbol = GetCharSymbol(w.first);
+    // Same float-before-negation rule as for the merged pieces above.
+    const float score = -static_cast<float>(final_pieces_.size());
+    final_pieces_.emplace_back(symbol->ToString(), score);
+  }
+
+  Save();
+
+  port::STLDeleteElements(&allocated_);
+}
+} // namespace bpe
+} // namespace sentencepiece