// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #include "bpe_model_trainer.h" #include #include #include #include #include "util.h" namespace sentencepiece { namespace bpe { std::string Trainer::Symbol::ToString() const { return string_util::UnicodeTextToUTF8(chars); } Trainer::Symbol *Trainer::GetCharSymbol(char32 c) { const uint64 freq = port::FindWithDefault(required_chars_, c, 1); CHECK_GT(freq, 0); const auto it = symbols_cache_.find(c); if (it != symbols_cache_.end()) { return it->second; } Symbol *s = new Symbol; allocated_.push_back(s); s->is_unk = (kUNKChar == c); s->fp = c; s->chars.push_back(c); s->freq = freq; port::InsertOrDie(&symbols_cache_, s->fp, s); return s; } Trainer::Symbol *Trainer::GetPairSymbol(const Symbol *left, const Symbol *right) { if (left == nullptr || right == nullptr || left->is_unk || right->is_unk) { return nullptr; } const uint64 fp = port::FingerprintCat(left->fp, right->fp); const auto it = symbols_cache_.find(fp); if (it != symbols_cache_.end()) { return it->second; } CHECK(!left->chars.empty()); CHECK(!right->chars.empty()); string_util::UnicodeText ut; for (const char32 c : left->chars) ut.push_back(c); for (const char32 c : right->chars) ut.push_back(c); // Do not make an invalid piece. if (!IsValidSentencePiece(ut)) { return nullptr; } Symbol *s = new Symbol; allocated_.push_back(s); s->fp = fp; s->left = left; s->right = right; s->chars = ut; port::InsertOrDie(&symbols_cache_, s->fp, s); return s; } void Trainer::ComputeFreq(Symbol *symbol) const { if (symbol->freq > 0) { // if freq == 0, re-computation is required. return; } // Avoids double-count. ("AAA" => only count the first "AA"). Position prev_pos = {-1, 0}; CHECK_EQ(0, symbol->freq); for (auto it = symbol->positions.begin(); it != symbol->positions.end();) { const Position pos = DecodePos(*it); // There are two same bigrams in "AAA", [AA] [AA], and we want to // remove the second one to avoid double counts. // If the right symbol in the first bigram and the left symbol in the // second bigram have the same position, (pos.left == prev_pos.right), // duplicated bigram exisit. // Also, symbols_[sid][left] and symbols_[sid]right] must store // the same symbols in symbol->left and symbols->right. if ((pos.sid == prev_pos.sid && pos.left == prev_pos.right) || symbol->left != symbols_[pos.sid][pos.left] || symbol->right != symbols_[pos.sid][pos.right]) { it = symbol->positions.erase(it); // Initializes prev_pos. // In "AAAA", the last "AA" can be counted. prev_pos = {-1, 0}; } else { symbol->freq += sentences_[pos.sid].second; prev_pos = pos; ++it; } } } int Trainer::GetNextIndex(int sid, int index) const { for (size_t i = index + 1; i < symbols_[sid].size(); ++i) { if (symbols_[sid][i] == nullptr) continue; return i; } return -1; } int Trainer::GetPrevIndex(int sid, int index) const { for (int i = index - 1; i >= 0; --i) { if (symbols_[sid][i] == nullptr) continue; return i; } return -1; } void Trainer::AddNewPair(int sid, int left, int right) { if (left == -1 || right == -1) return; auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]); if (symbol != nullptr) { active_symbols_.insert(symbol); symbol->positions.insert(EncodePos(sid, left, right)); } } void Trainer::ResetFreq(int sid, int left, int right, const Symbol *best) { if (left == -1 || right == -1) return; auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]); if (symbol != nullptr && symbol != best) { symbol->freq = 0; } } void Trainer::UpdateActiveSymbols() { std::vector symbols; for (auto &it : symbols_cache_) { Symbol *symbol = it.second; if (symbol->IsBigram()) { ComputeFreq(symbol); symbols.push_back(symbol); } } // At least kMinActiveSymbolsSize symbols must be in |active_symbols_|. constexpr int kMinActiveSymbolsSize = 1000; // Keeps top 5% frequent symbols. constexpr float kTopFrequentRatio = 0.05; const int size = std::min(std::max(kMinActiveSymbolsSize, symbols_cache_.size() * kTopFrequentRatio), symbols.size()); std::partial_sort(symbols.begin(), symbols.begin() + size, symbols.end(), [](Symbol *s1, Symbol *s2) { return s1->freq > s2->freq; }); LOG(INFO) << "Updating active symbols. max_freq=" << symbols[0]->freq << " min_freq=" << symbols[size - 1]->freq; active_symbols_.clear(); active_symbols_.insert(symbols.begin(), symbols.begin() + size); } util::Status Trainer::Train() { RETURN_IF_ERROR(status()); LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString(); CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces()); CHECK_EQ_OR_RETURN(TrainerSpec::BPE, trainer_spec_.model_type()); symbols_.clear(); allocated_.clear(); symbols_cache_.clear(); active_symbols_.clear(); // Load all sentences RETURN_IF_ERROR(LoadSentences()); if (trainer_spec_.split_by_whitespace()) { SplitSentencesByWhitespace(); } // Initializes symbols_. symbols_[sid][i] stores an unary symbol. symbols_.resize(sentences_.size()); for (size_t i = 0; i < sentences_.size(); ++i) { for (const char32 c : string_util::UTF8ToUnicodeText(sentences_[i].first)) { symbols_[i].push_back(GetCharSymbol(c)); } } // Makes all bigram symbols. for (size_t sid = 0; sid < symbols_.size(); ++sid) { for (size_t i = 1; i < symbols_[sid].size(); ++i) { AddNewPair(sid, i - 1, i); } } const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size() - required_chars_.size(); CHECK_GE_OR_RETURN(vocab_size, 0); // We may see duplicated pieces that are extracted with different path. // In real segmentation phase, we can consider them as one symbol. // e.g., "aaa" => "aa" + "a" or "a" + "aa". std::unordered_set dup; // Main loop. CHECK_OR_RETURN(final_pieces_.empty()); while (final_pieces_.size() < static_cast(vocab_size)) { constexpr int kUpdateActiveSymbolsInteval = 100; if (final_pieces_.size() % kUpdateActiveSymbolsInteval == 0) { UpdateActiveSymbols(); } // Scanning active symbols, finds the best_symbol with highest freq. Symbol *best_symbol = nullptr; for (auto &it : active_symbols_) { Symbol *symbol = it; ComputeFreq(symbol); // If the frequency is the same, take shorter symbol. // if the length is the same, use lexicographical comparison if (best_symbol == nullptr || (symbol->freq > best_symbol->freq || (symbol->freq == best_symbol->freq && (symbol->chars.size() < best_symbol->chars.size() || (symbol->chars.size() == best_symbol->chars.size() && symbol->ToString() < best_symbol->ToString()))))) { best_symbol = symbol; } } if (best_symbol == nullptr) { LOG(WARNING) << "No valid symbol found"; break; } if (!dup.insert(best_symbol->ToString()).second) { // Removes best_symbol so it is not selected again. symbols_cache_.erase(best_symbol->fp); active_symbols_.erase(best_symbol); continue; } // Stores the best_symbol in the final output. final_pieces_.emplace_back(best_symbol->ToString(), -static_cast(final_pieces_.size())); if (final_pieces_.size() % 20 == 0) { LOG(INFO) << "Added: freq=" << best_symbol->freq << " size=" << final_pieces_.size() << " all=" << symbols_cache_.size() << " active=" << active_symbols_.size() << " piece=" << best_symbol->ToString(); } // Add new bigrams which are created after symbol replacement. // We do not need to scan all characters, but scan the neighbors in // best_symbol. for (const uint64 &encoded_pos : best_symbol->positions) { const Position pos = DecodePos(encoded_pos); if (symbols_[pos.sid][pos.left] == nullptr) { // left index might be NULL (set in the privous iteration) // when left_symbol == right_symbol. continue; } CHECK_OR_RETURN(symbols_[pos.sid][pos.right]); // We have three bigrams [prev, left], [left, right], [right, next], // which are affected with this symbol replacement. const int next = GetNextIndex(pos.sid, pos.right); const int prev = GetPrevIndex(pos.sid, pos.left); // Resets the frequencies of bigrams [prev, left] and [right, next]. ResetFreq(pos.sid, prev, pos.left, best_symbol); ResetFreq(pos.sid, pos.right, next, best_symbol); // Merges two symbols. symbols_[pos.sid][pos.left] = best_symbol; symbols_[pos.sid][pos.right] = nullptr; // Makes new symbol bigrams [prev, left] and [left, next]. AddNewPair(pos.sid, prev, pos.left); AddNewPair(pos.sid, pos.left, next); } // Removes best_symbol so it is not selected again. symbols_cache_.erase(best_symbol->fp); active_symbols_.erase(best_symbol); } // end of main loop // Adds required_chars_ for (const auto &w : Sorted(required_chars_)) { const Symbol *symbol = GetCharSymbol(w.first); final_pieces_.emplace_back(symbol->ToString(), -static_cast(final_pieces_.size())); } port::STLDeleteElements(&allocated_); return Save(); } } // namespace bpe } // namespace sentencepiece