diff options
author | Taku Kudo <taku@google.com> | 2017-03-07 13:43:50 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2017-03-07 13:43:50 +0300 |
commit | 2928ce5307224ea4c012fc6cbd7a098c486590b6 (patch) | |
tree | 38b679886855a7a6b80fdc61f2f62c952cf3bfb7 /src/bpe_model_trainer.cc |
Initialize repository
Diffstat (limited to 'src/bpe_model_trainer.cc')
-rw-r--r-- | src/bpe_model_trainer.cc | 323 |
1 files changed, 323 insertions, 0 deletions
diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc new file mode 100644 index 0000000..f80dd75 --- /dev/null +++ b/src/bpe_model_trainer.cc @@ -0,0 +1,323 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +#include "bpe_model_trainer.h" + +#include <unordered_set> +#include "util.h" + +namespace sentencepiece { +namespace bpe { + +std::string Trainer::Symbol::ToString() const { + return string_util::UnicodeTextToUTF8(chars); +} + +Trainer::Symbol *Trainer::GetCharSymbol(char32 c) { + const uint64 freq = port::FindWithDefault(required_chars_, c, 1); + CHECK_GT(freq, 0); + const auto it = symbols_cache_.find(c); + if (it != symbols_cache_.end()) { + return it->second; + } + Symbol *s = new Symbol; + allocated_.push_back(s); + s->is_unk = (kUNKChar == c); + s->fp = c; + s->chars.push_back(c); + s->freq = freq; + port::InsertOrDie(&symbols_cache_, s->fp, s); + return s; +} + +Trainer::Symbol *Trainer::GetPairSymbol(const Symbol *left, + const Symbol *right) { + if (left == nullptr || right == nullptr || left->is_unk || right->is_unk) { + return nullptr; + } + + const uint64 fp = port::FingerprintCat(left->fp, right->fp); + const auto it = symbols_cache_.find(fp); + if (it != symbols_cache_.end()) { + return it->second; + } + + CHECK(!left->chars.empty()); + CHECK(!right->chars.empty()); + string_util::UnicodeText ut; + for (const char32 c : left->chars) ut.push_back(c); + for (const char32 c : right->chars) ut.push_back(c); + + // Do not make an invalid piece. + if (!IsValidSentencePiece(ut)) { + return nullptr; + } + + Symbol *s = new Symbol; + allocated_.push_back(s); + s->fp = fp; + s->left = left; + s->right = right; + s->chars = ut; + port::InsertOrDie(&symbols_cache_, s->fp, s); + return s; +} + +void Trainer::ComputeFreq(Symbol *symbol) const { + if (symbol->freq > 0) { // if freq == 0, re-computation is required. + return; + } + // Avoids double-count. ("AAA" => only count the first "AA"). + Position prev_pos = {-1, 0}; + CHECK_EQ(0, symbol->freq); + for (auto it = symbol->positions.begin(); it != symbol->positions.end();) { + const Position pos = DecodePos(*it); + // There are two same bigrams in "AAA", [AA] [AA], and we want to + // remove the second one to avoid double counts. + // If the right symbol in the first bigram and the left symbol in the + // second bigram have the same position, (pos.left == prev_pos.right), + // duplicated bigram exisit. + // Also, symbols_[sid][left] and symbols_[sid]right] must store + // the same symbols in symbol->left and symbols->right. + if ((pos.sid == prev_pos.sid && pos.left == prev_pos.right) || + symbol->left != symbols_[pos.sid][pos.left] || + symbol->right != symbols_[pos.sid][pos.right]) { + it = symbol->positions.erase(it); + // Initializes prev_pos. + // In "AAAA", the last "AA" can be counted. + prev_pos = {-1, 0}; + } else { + symbol->freq += sentences_[pos.sid].second; + prev_pos = pos; + ++it; + } + } +} + +int Trainer::GetNextIndex(int sid, int index) const { + for (size_t i = index + 1; i < symbols_[sid].size(); ++i) { + if (symbols_[sid][i] == nullptr) continue; + return i; + } + return -1; +} + +int Trainer::GetPrevIndex(int sid, int index) const { + for (int i = index - 1; i >= 0; --i) { + if (symbols_[sid][i] == nullptr) continue; + return i; + } + return -1; +} + +void Trainer::AddNewPair(int sid, int left, int right) { + if (left == -1 || right == -1) return; + auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]); + if (symbol != nullptr) { + active_symbols_.insert(symbol); + symbol->positions.insert(EncodePos(sid, left, right)); + } +} + +void Trainer::ResetFreq(int sid, int left, int right, const Symbol *best) { + if (left == -1 || right == -1) return; + auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]); + if (symbol != nullptr && symbol != best) { + symbol->freq = 0; + } +} + +void Trainer::UpdateActiveSymbols() { + std::vector<Symbol *> symbols; + for (auto &it : symbols_cache_) { + Symbol *symbol = it.second; + if (symbol->IsBigram()) { + ComputeFreq(symbol); + symbols.push_back(symbol); + } + } + + // At least kMinActiveSymbolsSize symbols must be in |active_symbols_|. + constexpr int kMinActiveSymbolsSize = 1000; + + // Keeps top 5% frequent symbols. + constexpr float kTopFrequentRatio = 0.05; + const int size = + std::min<int>(std::max<int>(kMinActiveSymbolsSize, + symbols_cache_.size() * kTopFrequentRatio), + symbols.size()); + + std::partial_sort(symbols.begin(), symbols.begin() + size, symbols.end(), + [](Symbol *s1, Symbol *s2) { return s1->freq > s2->freq; }); + LOG(INFO) << "Updating active symbols. max_freq=" << symbols[0]->freq + << " min_freq=" << symbols[size - 1]->freq; + + active_symbols_.clear(); + active_symbols_.insert(symbols.begin(), symbols.begin() + size); +} + +void Trainer::Train() { +#define CHECK_RANGE(variable, minval, maxval) \ + CHECK(variable >= minval && variable <= maxval) + + CHECK_GT(trainer_spec_.input().size(), 0); + CHECK(!trainer_spec_.model_prefix().empty()); + CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0); + CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000); + CHECK_RANGE(trainer_spec_.max_sentencepiece_length(), 1, 64); + CHECK_GT(trainer_spec_.vocab_size(), 0); +#undef CHECK_RANGE + + LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString(); + + CHECK(normalizer_spec_.escape_whitespaces()); + CHECK_EQ(TrainerSpec::BPE, trainer_spec_.model_type()); + + symbols_.clear(); + allocated_.clear(); + symbols_cache_.clear(); + active_symbols_.clear(); + + // Load all sentences + LoadSentences(); + + if (trainer_spec_.split_by_whitespace()) { + SplitSentencesByWhitespace(); + } + + // Initializes symbols_. symbols_[sid][i] stores an unary symbol. + symbols_.resize(sentences_.size()); + for (size_t i = 0; i < sentences_.size(); ++i) { + for (const char32 c : string_util::UTF8ToUnicodeText(sentences_[i].first)) { + symbols_[i].push_back(GetCharSymbol(c)); + } + } + + // Makes all bigram symbols. + for (size_t sid = 0; sid < symbols_.size(); ++sid) { + for (size_t i = 1; i < symbols_[sid].size(); ++i) { + AddNewPair(sid, i - 1, i); + } + } + + const int meta_symbols_size = trainer_spec_.control_symbols().size() + + trainer_spec_.user_defined_symbols().size() + + 3; // <s>, </s>, <unk> + const int vocab_size = + trainer_spec_.vocab_size() - meta_symbols_size - required_chars_.size(); + CHECK_GE(vocab_size, 0); + + // We may see duplicated pieces that are extracted with different path. + // In real segmentation phase, we can consider them as one symbol. + // e.g., "aaa" => "aa" + "a" or "a" + "aa". + std::unordered_set<std::string> dup; + + // Main loop. + CHECK(final_pieces_.empty()); + while (final_pieces_.size() < static_cast<size_t>(vocab_size)) { + constexpr int kUpdateActiveSymbolsInteval = 100; + if (final_pieces_.size() % kUpdateActiveSymbolsInteval == 0) { + UpdateActiveSymbols(); + } + + // Scanning active symbols, finds the best_symbol with highest freq. + Symbol *best_symbol = nullptr; + for (auto &it : active_symbols_) { + Symbol *symbol = it; + ComputeFreq(symbol); + // If the frequency is the same, take shorter symbol. + // if the length is the same, use lexicographical comparison + if (best_symbol == nullptr || + (symbol->freq > best_symbol->freq || + (symbol->freq == best_symbol->freq && + (symbol->chars.size() < best_symbol->chars.size() || + (symbol->chars.size() == best_symbol->chars.size() && + symbol->ToString() < best_symbol->ToString()))))) { + best_symbol = symbol; + } + } + + if (best_symbol == nullptr) { + LOG(WARNING) << "No valid symbol found"; + break; + } + + if (!dup.insert(best_symbol->ToString()).second) { + // Removes best_symbol so it is not selected again. + symbols_cache_.erase(best_symbol->fp); + active_symbols_.erase(best_symbol); + continue; + } + + // Stores the best_symbol in the final output. + const float score = -final_pieces_.size(); + final_pieces_.emplace_back(best_symbol->ToString(), score); + + if (final_pieces_.size() % 20 == 0) { + LOG(INFO) << "Added: freq=" << best_symbol->freq + << " size=" << final_pieces_.size() + << " all=" << symbols_cache_.size() + << " active=" << active_symbols_.size() + << " piece=" << best_symbol->ToString(); + } + + // Add new bigrams which are created after symbol replacement. + // We do not need to scan all characters, but scan the neighbors in + // best_symbol. + for (const uint64 &encoded_pos : best_symbol->positions) { + const Position pos = DecodePos(encoded_pos); + + if (symbols_[pos.sid][pos.left] == nullptr) { + // left index might be NULL (set in the privous iteration) + // when left_symbol == right_symbol. + continue; + } + CHECK_NOTNULL(symbols_[pos.sid][pos.right]); + + // We have three bigrams [prev, left], [left, right], [right, next], + // which are affected with this symbol replacement. + const int next = GetNextIndex(pos.sid, pos.right); + const int prev = GetPrevIndex(pos.sid, pos.left); + + // Resets the frequencies of bigrams [prev, left] and [right, next]. + ResetFreq(pos.sid, prev, pos.left, best_symbol); + ResetFreq(pos.sid, pos.right, next, best_symbol); + + // Merges two symbols. + symbols_[pos.sid][pos.left] = best_symbol; + symbols_[pos.sid][pos.right] = nullptr; + + // Makes new symbol bigrams [prev, left] and [left, next]. + AddNewPair(pos.sid, prev, pos.left); + AddNewPair(pos.sid, pos.left, next); + } + + // Removes best_symbol so it is not selected again. + symbols_cache_.erase(best_symbol->fp); + active_symbols_.erase(best_symbol); + } // end of main loop + + // Adds required_chars_ + for (const auto &w : Sorted(required_chars_)) { + const Symbol *symbol = GetCharSymbol(w.first); + const float score = -final_pieces_.size(); + final_pieces_.emplace_back(symbol->ToString(), score); + } + + Save(); + + port::STLDeleteElements(&allocated_); +} +} // namespace bpe +} // namespace sentencepiece |