diff options
author | Taku Kudo <taku@google.com> | 2018-09-16 11:40:16 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-09-16 11:40:16 +0300 |
commit | b66b41641385b592a1a20431cbf4a775466a79a2 (patch) | |
tree | 79c67b7421f2fa1eb16f32595def3a795182d8d6 | |
parent | c14d581837868d67b17abfc2b8b382c1c9c3660d (diff) |
performance tuning
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/bpe_model.cc | 17 | ||||
-rw-r--r-- | src/freelist.h | 82 | ||||
-rw-r--r-- | src/freelist_test.cc | 34 | ||||
-rw-r--r-- | src/model_interface.cc | 29 | ||||
-rw-r--r-- | src/model_interface.h | 62 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 5 | ||||
-rw-r--r-- | src/sentencepiece_processor.h | 2 | ||||
-rw-r--r-- | src/sentencepiece_processor_test.cc | 5 | ||||
-rw-r--r-- | src/unigram_model.cc | 75 | ||||
-rw-r--r-- | src/unigram_model.h | 51 | ||||
-rw-r--r-- | src/unigram_model_trainer.cc | 9 | ||||
-rw-r--r-- | src/unigram_model_trainer.h | 26 |
13 files changed, 256 insertions, 142 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 96be210..ebfcaa6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -29,6 +29,7 @@ set(SPM_SRCS common.h normalizer.h util.h + freelist.h filesystem.h flags.h sentencepiece_processor.h diff --git a/src/bpe_model.cc b/src/bpe_model.cc index e97fb67..26c98ed 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -20,6 +20,7 @@ #include <unordered_map> #include <utility> #include <vector> +#include "freelist.h" #include "util.h" namespace sentencepiece { @@ -73,9 +74,13 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode( string_util::string_view_hash> rev_merge; + // Pre-allocates SymbolPair for efficiency. + constexpr size_t kPreallocateSymbolPairSize = 256; + model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize); + // Lookup new symbol pair at [left, right] and inserts it to agenda. - auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge]( - int left, int right) { + auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda, + &rev_merge](int left, int right) { if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) return; @@ -86,7 +91,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode( if (it == pieces_.end()) { return; } - auto *h = new SymbolPair; + auto *h = symbol_pair_allocator.Allocate(); h->left = left; h->right = right; h->score = GetScore(it->second); @@ -94,7 +99,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode( agenda.push(h); // Makes `rev_merge` for resegmentation. - if (IsUnused(it->second)) { + if (IsUnusedInlined(it->second)) { rev_merge[piece] = std::make_pair(symbols[left].piece, symbols[right].piece); } @@ -124,7 +129,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode( // Main loop. while (!agenda.empty()) { - std::unique_ptr<SymbolPair> top(agenda.top()); + SymbolPair *top = agenda.top(); agenda.pop(); // `top` is no longer available. 
@@ -155,7 +160,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode( resegment = [this, &resegment, &rev_merge](absl::string_view w, EncodeResult *output) -> void { const int id = PieceToId(w); - if (id == -1 || !IsUnused(id)) { + if (id == -1 || !IsUnusedInlined(id)) { output->emplace_back(w, id); return; } diff --git a/src/freelist.h b/src/freelist.h new file mode 100644 index 0000000..e39c338 --- /dev/null +++ b/src/freelist.h @@ -0,0 +1,82 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +#ifndef FREELIST_H_ +#define FREELIST_H_ + +#include <string.h> +#include <vector> + +namespace sentencepiece { +namespace model { + +// Simple FreeList that allocates a chunk of T at once. +template <class T> +class FreeList { + public: + FreeList() = delete; + explicit FreeList(size_t chunk_size) : chunk_size_(chunk_size) {} + virtual ~FreeList() { + for (auto& chunk : freelist_) delete[] chunk; + } + + // `Free` doesn't free the object but reuse the allocated memory chunks. + void Free() { + const int size = std::min<int>(chunk_index_ + 1, freelist_.size()); + for (int i = 0; i < size; ++i) { + T* chunk = freelist_[i]; + memset(chunk, 0, sizeof(*chunk) * chunk_size_); + } + chunk_index_ = 0; + element_index_ = 0; + } + + // Returns the number of allocated elements. + size_t size() const { return chunk_size_ * chunk_index_ + element_index_; } + + // Returns the element as an array. 
+ T* operator[](size_t index) const { + return freelist_[index / chunk_size_] + index % chunk_size_; + } + + // Allocates new element. + T* Allocate() { + if (element_index_ >= chunk_size_) { + ++chunk_index_; + element_index_ = 0; + } + + if (chunk_index_ == freelist_.size()) { + T* chunk = new T[chunk_size_]; + memset(chunk, 0, sizeof(*chunk) * chunk_size_); + freelist_.push_back(chunk); + } + + T* result = freelist_[chunk_index_] + element_index_; + ++element_index_; + + return result; + } + + private: + std::vector<T*> freelist_; + + // The last element is stored at freelist_[chunk_index_][element_index_] + size_t element_index_ = 0; + size_t chunk_index_ = 0; + const size_t chunk_size_ = 0; +}; +} // namespace model +} // namespace sentencepiece +#endif // FREELIST_H_ diff --git a/src/freelist_test.cc b/src/freelist_test.cc new file mode 100644 index 0000000..a7ff7de --- /dev/null +++ b/src/freelist_test.cc @@ -0,0 +1,34 @@ +#include "freelist.h" +#include "testharness.h" + +namespace sentencepiece { +namespace model { + +TEST(FreeListTest, BasicTest) { + FreeList<int> l(5); + EXPECT_EQ(0, l.size()); + + constexpr size_t kSize = 32; + + for (size_t i = 0; i < kSize; ++i) { + int *n = l.Allocate(); + EXPECT_EQ(0, *n); + *n = i; + } + + EXPECT_EQ(kSize, l.size()); + for (size_t i = 0; i < kSize; ++i) { + EXPECT_EQ(i, *l[i]); + } + + l.Free(); + EXPECT_EQ(0, l.size()); + + // Zero-initialized after `Free`. 
+ for (size_t i = 0; i < kSize; ++i) { + int *n = l.Allocate(); + EXPECT_EQ(0, *n); + } +} +} // namespace model +} // namespace sentencepiece diff --git a/src/model_interface.cc b/src/model_interface.cc index 0a99e02..3b21388 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -91,35 +91,6 @@ int ModelInterface::PieceToId(absl::string_view piece) const { return unk_id_; } -int ModelInterface::GetPieceSize() const { return model_proto_->pieces_size(); } - -std::string ModelInterface::IdToPiece(int id) const { - return model_proto_->pieces(id).piece(); -} - -float ModelInterface::GetScore(int id) const { - return model_proto_->pieces(id).score(); -} - -bool ModelInterface::IsControl(int id) const { - return (model_proto_->pieces(id).type() == - ModelProto::SentencePiece::CONTROL); -} - -bool ModelInterface::IsUnknown(int id) const { - return (model_proto_->pieces(id).type() == - ModelProto::SentencePiece::UNKNOWN); -} - -bool ModelInterface::IsUnused(int id) const { - return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNUSED); -} - -bool ModelInterface::IsUserDefined(int id) const { - return (model_proto_->pieces(id).type() == - ModelProto::SentencePiece::USER_DEFINED); -} - void ModelInterface::InitializePieces(bool use_prefix_matcher) { pieces_.clear(); reserved_id_map_.clear(); diff --git a/src/model_interface.h b/src/model_interface.h index a7a6350..2e3f670 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -23,6 +23,7 @@ #include <vector> #include "common.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" #include "third_party/darts_clone/darts.h" @@ -99,38 +100,79 @@ class ModelInterface { return EncodeResult(); } - // Returns the size of sentence pieces, which is the same - // as the size of vocabulary for NMT. - virtual int GetPieceSize() const; - // Returns the vocab id of `piece`. 
// Returns UNK(0) if `piece` is unknown virtual int PieceToId(absl::string_view piece) const; // Returns the string representation of vocab with `id`. // id must be 0 <= id < GetPieceSize(). - virtual std::string IdToPiece(int id) const; + virtual const std::string &IdToPiece(int id) const { + return model_proto_->pieces(id).piece(); + } + + // Returns the size of sentence pieces, which is the same + // as the size of vocabulary for NMT. + virtual int GetPieceSize() const { return model_proto_->pieces_size(); } // Returns the score of `id`. // Score represents a log probability of the piece. // We can roughly estimate the unigram frequency of the piece. - virtual float GetScore(int id) const; + virtual float GetScore(int id) const { + return model_proto_->pieces(id).score(); + } // Returns true if `id` is unknown symbol. - virtual bool IsUnknown(int id) const; + virtual bool IsUnknown(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::UNKNOWN); + } // Returns true if `id` is control symbol. - virtual bool IsControl(int id) const; + virtual bool IsControl(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::CONTROL); + } // Returns true if `id` is unused symbol. - virtual bool IsUnused(int id) const; + virtual bool IsUnused(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::UNUSED); + } // Returns true if `id` is user defined symbol. - virtual bool IsUserDefined(int id) const; + virtual bool IsUserDefined(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::USER_DEFINED); + } protected: void InitializePieces(bool use_prefix_matcher); + // Non-virtual (inlined) implementation for faster execution. 
+ inline float GetScoreInlined(int id) const { + return model_proto_->pieces(id).score(); + } + + inline bool IsUnknownInlined(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::UNKNOWN); + } + + inline bool IsControlInlined(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::CONTROL); + } + + inline bool IsUnusedInlined(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::UNUSED); + } + + inline bool IsUserDefinedInlined(int id) const { + return (model_proto_->pieces(id).type() == + ModelProto::SentencePiece::USER_DEFINED); + } + const ModelProto *model_proto_ = nullptr; // PrefixMatcher for user defined symbols. diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 9bd41a0..f2f8967 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -521,8 +521,9 @@ int SentencePieceProcessor::PieceToId(util::min_string_view piece) const { return model_->PieceToId(string_util::ToSV(piece)); } -std::string SentencePieceProcessor::IdToPiece(int id) const { - CHECK_STATUS_OR_RETURN_DEFAULT(""); +const std::string &SentencePieceProcessor::IdToPiece(int id) const { + static const std::string *kEmptyString = new std::string; + CHECK_STATUS_OR_RETURN_DEFAULT(*kEmptyString); return model_->IdToPiece(id); } diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 1cd6c54..61da691 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -369,7 +369,7 @@ class SentencePieceProcessor { virtual int PieceToId(util::min_string_view piece) const; // Returns the string representation of vocab with `id`. - virtual std::string IdToPiece(int id) const; + virtual const std::string &IdToPiece(int id) const; // Returns the score of `id`. // Usually score is an emission log probability of unigram language model. 
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 793826d..c1ef138 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -71,7 +71,7 @@ class MockModel : public ModelInterface { int PieceToId(absl::string_view piece) const { return 0; } - std::string IdToPiece(int id) const { return ""; } + const std::string &IdToPiece(int id) const { return kEmptyString; } float GetScore(int id) const { return 0.0; } @@ -79,6 +79,7 @@ class MockModel : public ModelInterface { absl::string_view input_; EncodeResult output_; NBestEncodeResult nbest_output_; + const std::string kEmptyString; }; std::vector<std::string> GetSpVec(const EncodeResult &pieces) { @@ -457,7 +458,7 @@ TEST(SentencepieceProcessorTest, DecodeTest) { return port::FindWithDefault(kMap, piece, 0); } - std::string IdToPiece(int id) const override { + const std::string &IdToPiece(int id) const override { static std::vector<std::string> kMap = { "<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H"}; return kMap[id]; diff --git a/src/unigram_model.cc b/src/unigram_model.cc index b7e2d02..280dec7 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -30,6 +30,9 @@ namespace sentencepiece { namespace unigram { namespace { +// Size of nodes pre-allocated in Lattice. +constexpr size_t kPreallocateLatticeNodeSize = 1024; + // Returns log(exp(x) + exp(y)). // if init_mode is true, returns log(exp(y)) == y. 
// log(\sum_i exp(a[i])) can be computed as @@ -50,8 +53,8 @@ inline float LogSumExp(float x, float y, bool init_mode) { } } // namespace -Lattice::Lattice() {} -Lattice::~Lattice() { Clear(); } +Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {} +Lattice::~Lattice() {} const std::vector<Lattice::Node *> &Lattice::begin_nodes(int pos) const { return begin_nodes_[pos]; @@ -77,10 +80,8 @@ Lattice::Node *Lattice::bos_node() const { return end_nodes_[0][0]; } Lattice::Node *Lattice::eos_node() const { return begin_nodes_[size()][0]; } Lattice::Node *Lattice::NewNode() { - Node *node = new Node; - memset(node, 0, sizeof(*node)); - node->node_id = all_nodes_.size(); - all_nodes_.push_back(node); + Node *node = node_allocator_.Allocate(); + node->node_id = node_allocator_.size() - 1; return node; } @@ -89,14 +90,14 @@ void Lattice::Clear() { end_nodes_.clear(); sentence_ = absl::string_view(""); surface_.clear(); - port::STLDeleteElements(&all_nodes_); - all_nodes_.clear(); + node_allocator_.Free(); } void Lattice::SetSentence(absl::string_view sentence) { Clear(); sentence_ = sentence; + surface_.reserve(sentence.size() + 1); while (!sentence.empty()) { const int mblen = std::min<int>(string_util::OneCharLen(sentence.data()), @@ -110,9 +111,10 @@ void Lattice::SetSentence(absl::string_view sentence) { begin_nodes_.resize(len + 1); end_nodes_.resize(len + 1); + constexpr size_t kReservedNodeSize = 16; for (int i = 0; i <= len; ++i) { - begin_nodes_[i].reserve(16); - end_nodes_[i].reserve(16); + begin_nodes_[i].reserve(kReservedNodeSize); + end_nodes_[i].reserve(kReservedNodeSize); } Node *bos = NewNode(); @@ -183,8 +185,8 @@ float Lattice::PopulateMarginal(float freq, // alpha and beta (accumulative log prob) in Forward Backward. // the index of alpha/beta is Node::node_id. 
- std::vector<float> alpha(all_nodes_.size(), 0.0); - std::vector<float> beta(all_nodes_.size(), 0.0); + std::vector<float> alpha(node_allocator_.size(), 0.0); + std::vector<float> beta(node_allocator_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -256,19 +258,13 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) { using Agenda = std::priority_queue<Hypothesis *, std::vector<Hypothesis *>, HypothesisComparator>; + constexpr size_t kPreallocatedHypothesisSize = 512; + model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize); Agenda agenda; - std::vector<Hypothesis *> allocated; std::vector<std::vector<Node *>> results; - auto NewHypothesis = [&allocated]() { - Hypothesis *h = new Hypothesis; - memset(h, 0, sizeof(*h)); - allocated.push_back(h); - return h; - }; - - auto *eos = NewHypothesis(); + auto *eos = hypothesis_allocator.Allocate(); eos->node = eos_node(); eos->next = nullptr; eos->fx = eos->node->score; @@ -297,7 +293,7 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) { // Expands new node ending at node->pos for (Node *lnode : end_nodes(node->pos)) { - auto *hyp = NewHypothesis(); + auto *hyp = hypothesis_allocator.Allocate(); hyp->node = lnode; hyp->gx = lnode->score + top->gx; // just adds node->score hyp->fx = @@ -324,7 +320,6 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) { } } - port::STLDeleteElements(&allocated); return results; } @@ -332,7 +327,7 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) { const int len = size(); if (len == 0) return {}; - std::vector<float> alpha(all_nodes_.size(), 0.0); + std::vector<float> alpha(node_allocator_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -368,18 +363,14 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) { return results; } -ModelBase::ModelBase() {} -ModelBase::~ModelBase() 
{} +// Model::Model() {} +// Model::~Model() {} -void ModelBase::PopulateNodes(Lattice *lattice) const { - auto GetCharsLength = [](const char *begin, int len) { - const char *end = begin + len; - int result = 0; - while (begin < end) { - begin += std::min<int>(string_util::OneCharLen(begin), end - begin); - ++result; - } - return result; +void Model::PopulateNodes(Lattice *lattice) const { + auto get_chars_length = [&lattice](int begin_pos, const char *end) { + int pos = begin_pos; + while (lattice->surface(pos) < end) ++pos; + return pos - begin_pos; }; constexpr float kUnkPenalty = 10.0; @@ -405,14 +396,15 @@ void ModelBase::PopulateNodes(Lattice *lattice) const { // Inserts pieces to the lattice. for (size_t k = 0; k < num_nodes; ++k) { - const int length = GetCharsLength(begin, trie_results[k].length); + const int length = + get_chars_length(begin_pos, begin + trie_results[k].length); const int id = trie_results[k].value; - if (IsUnused(id)) continue; + if (IsUnusedInlined(id)) continue; Lattice::Node *node = lattice->Insert(begin_pos, length); node->id = id; // the value of Trie stores vocab_id. // User defined symbol receives extra bonus to always be selected. - node->score = - IsUserDefined(id) ? (length * max_score_ + 1.0) : GetScore(id); + node->score = IsUserDefinedInlined(id) ? (length * max_score_ + 1.0) + : GetScoreInlined(id); if (!has_single_node && node->length == 1) { has_single_node = true; } @@ -426,7 +418,7 @@ void ModelBase::PopulateNodes(Lattice *lattice) const { } } -int ModelBase::PieceToId(absl::string_view piece) const { +int Model::PieceToId(absl::string_view piece) const { auto it = reserved_id_map_.find(piece); if (it != reserved_id_map_.end()) { return it->second; @@ -436,8 +428,7 @@ int ModelBase::PieceToId(absl::string_view piece) const { return id == -1 ? 
unk_id_ : id; } -void ModelBase::BuildTrie( - std::vector<std::pair<absl::string_view, int>> *pieces) { +void Model::BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces) { if (!status().ok()) return; if (pieces->empty()) { diff --git a/src/unigram_model.h b/src/unigram_model.h index aee61b6..466a1c2 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -21,6 +21,7 @@ #include <vector> #include "common.h" +#include "freelist.h" #include "model_interface.h" #include "sentencepiece_model.pb.h" #include "third_party/darts_clone/darts.h" @@ -35,14 +36,14 @@ class Lattice { virtual ~Lattice(); struct Node { - absl::string_view piece; // Sentence piece representation. - uint32 pos; // Unicode position in the sentence. - uint32 length; // Unicode length, not UT8 byte. - uint32 node_id; // unique id in the current lattice. - int id; // vocab id. (maybe -1 for UNK) - float score; // logprob of this sentencepiece. - float backtrace_score; // backtrace info used in Viterbi. - Node *prev; // best previous node on Viterbi path. + absl::string_view piece; // Sentence piece representation. + uint32 pos; // Unicode position in the sentence. + uint32 length; // Unicode length, not UT8 byte. + uint32 node_id; // unique id in the current lattice. + int id; // vocab id. (maybe -1 for UNK) + float score; // logprob of this sentencepiece. + float backtrace_score; // backtrace info used in Viterbi. + Node *prev; // best previous node on Viterbi path. std::string DebugString() const; }; @@ -109,17 +110,22 @@ class Lattice { std::vector<const char *> surface_; std::vector<std::vector<Node *>> begin_nodes_; std::vector<std::vector<Node *>> end_nodes_; - std::vector<Node *> all_nodes_; + model::FreeList<Node> node_allocator_; }; -// Base class for Unigram Model. -// We have base Model class because we will have different -// implementations for training and testing. -// Trie management part is shared by training and testing. 
-class ModelBase : public ModelInterface { +class Model : public ModelInterface { public: - ModelBase(); - ~ModelBase() override; + explicit Model(const ModelProto &model_proto); + Model() {} + ~Model() override; + + EncodeResult Encode(absl::string_view normalized) const override; + + NBestEncodeResult NBestEncode(absl::string_view normalized, + int nbest_size) const override; + + EncodeResult SampleEncode(absl::string_view normalized, + float theta) const override; // Returns the minimum score in sentence pieces. // min_score() - 10 is used for the cost of unknown sentence. @@ -150,19 +156,6 @@ class ModelBase : public ModelInterface { int trie_results_size_; }; -// Unigram model class for decoding. -class Model : public ModelBase { - public: - explicit Model(const ModelProto &model_proto); - ~Model() override; - - EncodeResult Encode(absl::string_view normalized) const override; - - NBestEncodeResult NBestEncode(absl::string_view normalized, - int nbest_size) const override; - - EncodeResult SampleEncode(absl::string_view normalized, float theta) const override; -}; } // namespace unigram } // namespace sentencepiece #endif // UNIGRAM_MODEL_H_ diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index 0974cb4..04abe17 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -81,6 +81,7 @@ class ThreadPool { TrainerModel::TrainerModel(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) : trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {} + TrainerModel::~TrainerModel() {} const TrainerModel::SentencePieces &TrainerModel::GetSentencePieces() const { @@ -92,13 +93,19 @@ void TrainerModel::SetSentencePieces(SentencePieces &&sentencepieces) { CHECK(!sentencepieces_.empty()); min_score_ = FLT_MAX; + model_proto_data_.Clear(); + model_proto_ = &model_proto_data_; std::vector<std::pair<absl::string_view, int>> pieces; + for (size_t i = 0; i < sentencepieces_.size(); ++i) { const 
absl::string_view w = sentencepieces_[i].first; // piece - const float score = sentencepieces_[i].second; // score. + const float score = sentencepieces_[i].second; // score. CHECK(!std::isnan(score)); pieces.emplace_back(w, i); min_score_ = std::min(min_score_, score); + auto *piece = model_proto_data_.add_pieces(); + piece->set_piece(w.data(), w.size()); + piece->set_score(score); } BuildTrie(&pieces); diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h index 2bd31d0..fa069e2 100644 --- a/src/unigram_model_trainer.h +++ b/src/unigram_model_trainer.h @@ -32,11 +32,12 @@ namespace unigram { using string_util::UnicodeText; -class TrainerModel : public ModelBase { +class TrainerModel : public Model { public: using SentencePieces = std::vector<std::pair<std::string, float>>; - TrainerModel() = delete; + TrainerModel() {} + TrainerModel(const ModelProto &model_proto) = delete; TrainerModel(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizaiton_spec); ~TrainerModel() override; @@ -49,30 +50,15 @@ class TrainerModel : public ModelBase { // The meta symbols, e.g., </s> are NOT included. 
void SetSentencePieces(SentencePieces &&sentencepieces); - int GetPieceSize() const override { return sentencepieces_.size(); } - - float GetScore(int index) const override { - return sentencepieces_[index].second; - } - - std::string IdToPiece(int id) const override { - return sentencepieces_[id].first; + EncodeResult Encode(absl::string_view normalized) const override { + return {}; } - bool IsControl(int id) const override { return false; } - - bool IsUnknown(int id) const override { return false; } - - bool IsUnused(int id) const override { return false; } - - bool IsUserDefined(int id) const override { return false; } - - EncodeResult Encode(absl::string_view normalized) const override { return {}; } - private: SentencePieces sentencepieces_; TrainerSpec trainer_spec_; NormalizerSpec normalizer_spec_; + ModelProto model_proto_data_; }; class Trainer : public TrainerInterface { |