github.com/marian-nmt/sentencepiece.git
author    Taku Kudo <taku@google.com>  2018-09-16 11:40:16 +0300
committer Taku Kudo <taku@google.com>  2018-09-16 11:40:16 +0300
commit    b66b41641385b592a1a20431cbf4a775466a79a2 (patch)
tree      79c67b7421f2fa1eb16f32595def3a795182d8d6
parent    c14d581837868d67b17abfc2b8b382c1c9c3660d (diff)
performance tuning
-rw-r--r--  src/CMakeLists.txt                  |  1
-rw-r--r--  src/bpe_model.cc                    | 17
-rw-r--r--  src/freelist.h                      | 82
-rw-r--r--  src/freelist_test.cc                | 34
-rw-r--r--  src/model_interface.cc              | 29
-rw-r--r--  src/model_interface.h               | 62
-rw-r--r--  src/sentencepiece_processor.cc      |  5
-rw-r--r--  src/sentencepiece_processor.h       |  2
-rw-r--r--  src/sentencepiece_processor_test.cc |  5
-rw-r--r--  src/unigram_model.cc                | 75
-rw-r--r--  src/unigram_model.h                 | 51
-rw-r--r--  src/unigram_model_trainer.cc        |  9
-rw-r--r--  src/unigram_model_trainer.h         | 26
13 files changed, 256 insertions(+), 142 deletions(-)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 96be210..ebfcaa6 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -29,6 +29,7 @@ set(SPM_SRCS
common.h
normalizer.h
util.h
+ freelist.h
filesystem.h
flags.h
sentencepiece_processor.h
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
index e97fb67..26c98ed 100644
--- a/src/bpe_model.cc
+++ b/src/bpe_model.cc
@@ -20,6 +20,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
+#include "freelist.h"
#include "util.h"
namespace sentencepiece {
@@ -73,9 +74,13 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode(
string_util::string_view_hash>
rev_merge;
+ // Pre-allocates SymbolPair for efficiency.
+ constexpr size_t kPreallocateSymbolPairSize = 256;
+ model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize);
+
// Lookup new symbol pair at [left, right] and inserts it to agenda.
- auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
- int left, int right) {
+ auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda,
+ &rev_merge](int left, int right) {
if (left == -1 || right == -1 || symbols[left].freeze ||
symbols[right].freeze)
return;
@@ -86,7 +91,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode(
if (it == pieces_.end()) {
return;
}
- auto *h = new SymbolPair;
+ auto *h = symbol_pair_allocator.Allocate();
h->left = left;
h->right = right;
h->score = GetScore(it->second);
@@ -94,7 +99,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode(
agenda.push(h);
// Makes `rev_merge` for resegmentation.
- if (IsUnused(it->second)) {
+ if (IsUnusedInlined(it->second)) {
rev_merge[piece] =
std::make_pair(symbols[left].piece, symbols[right].piece);
}
@@ -124,7 +129,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode(
// Main loop.
while (!agenda.empty()) {
- std::unique_ptr<SymbolPair> top(agenda.top());
+ SymbolPair *top = agenda.top();
agenda.pop();
// `top` is no longer available.
@@ -155,7 +160,7 @@ std::vector<std::pair<absl::string_view, int>> Model::Encode(
resegment = [this, &resegment, &rev_merge](absl::string_view w,
EncodeResult *output) -> void {
const int id = PieceToId(w);
- if (id == -1 || !IsUnused(id)) {
+ if (id == -1 || !IsUnusedInlined(id)) {
output->emplace_back(w, id);
return;
}
diff --git a/src/freelist.h b/src/freelist.h
new file mode 100644
index 0000000..e39c338
--- /dev/null
+++ b/src/freelist.h
@@ -0,0 +1,82 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.!
+
+#ifndef FREELIST_H_
+#define FREELIST_H_
+
+#include <string.h>
+#include <vector>
+
+namespace sentencepiece {
+namespace model {
+
+// Simple FreeList that allocates a chunk of T at once.
+template <class T>
+class FreeList {
+ public:
+ FreeList() = delete;
+ explicit FreeList(size_t chunk_size) : chunk_size_(chunk_size) {}
+ virtual ~FreeList() {
+ for (auto& chunk : freelist_) delete[] chunk;
+ }
+
+ // `Free` doesn't free the object but reuse the allocated memory chunks.
+ void Free() {
+ const int size = std::min<int>(chunk_index_ + 1, freelist_.size());
+ for (int i = 0; i < size; ++i) {
+ T* chunk = freelist_[i];
+ memset(chunk, 0, sizeof(*chunk) * chunk_size_);
+ }
+ chunk_index_ = 0;
+ element_index_ = 0;
+ }
+
+ // Returns the number of allocated elements.
+ size_t size() const { return chunk_size_ * chunk_index_ + element_index_; }
+
+ // Returns the element as an array.
+ T* operator[](size_t index) const {
+ return freelist_[index / chunk_size_] + index % chunk_size_;
+ }
+
+ // Allocates new element.
+ T* Allocate() {
+ if (element_index_ >= chunk_size_) {
+ ++chunk_index_;
+ element_index_ = 0;
+ }
+
+ if (chunk_index_ == freelist_.size()) {
+ T* chunk = new T[chunk_size_];
+ memset(chunk, 0, sizeof(*chunk) * chunk_size_);
+ freelist_.push_back(chunk);
+ }
+
+ T* result = freelist_[chunk_index_] + element_index_;
+ ++element_index_;
+
+ return result;
+ }
+
+ private:
+ std::vector<T*> freelist_;
+
+ // The last element is stored at freelist_[chunk_index_][element_index_]
+ size_t element_index_ = 0;
+ size_t chunk_index_ = 0;
+ const size_t chunk_size_ = 0;
+};
+} // namespace model
+} // namespace sentencepiece
+#endif // FREELIST_H_
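
For quick reference, a minimal usage sketch of the FreeList added above (hypothetical snippet for illustration, not part of the commit; it assumes only the interface defined in freelist.h and a plain struct as the element type):

#include "freelist.h"

struct SymbolPair {
  int left;
  int right;
  float score;
};

void Example() {
  // Pool that hands out SymbolPair objects in chunks of 256.
  sentencepiece::model::FreeList<SymbolPair> allocator(256);

  // Allocate() returns a pointer into the current chunk; the memory is
  // zero-initialized, and a new chunk is allocated once the current one
  // is exhausted.
  SymbolPair *pair = allocator.Allocate();
  pair->left = 0;
  pair->right = 1;
  pair->score = -1.5;

  // Free() does not return memory to the system; it re-zeroes the used
  // chunks and resets the indices so the next call can reuse them.
  allocator.Free();
}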
diff --git a/src/freelist_test.cc b/src/freelist_test.cc
new file mode 100644
index 0000000..a7ff7de
--- /dev/null
+++ b/src/freelist_test.cc
@@ -0,0 +1,34 @@
+#include "freelist.h"
+#include "testharness.h"
+
+namespace sentencepiece {
+namespace model {
+
+TEST(FreeListTest, BasicTest) {
+ FreeList<int> l(5);
+ EXPECT_EQ(0, l.size());
+
+ constexpr size_t kSize = 32;
+
+ for (size_t i = 0; i < kSize; ++i) {
+ int *n = l.Allocate();
+ EXPECT_EQ(0, *n);
+ *n = i;
+ }
+
+ EXPECT_EQ(kSize, l.size());
+ for (size_t i = 0; i < kSize; ++i) {
+ EXPECT_EQ(i, *l[i]);
+ }
+
+ l.Free();
+ EXPECT_EQ(0, l.size());
+
+ // Zero-initialized after `Free`.
+ for (size_t i = 0; i < kSize; ++i) {
+ int *n = l.Allocate();
+ EXPECT_EQ(0, *n);
+ }
+}
+} // namespace model
+} // namespace sentencepiece
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 0a99e02..3b21388 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -91,35 +91,6 @@ int ModelInterface::PieceToId(absl::string_view piece) const {
return unk_id_;
}
-int ModelInterface::GetPieceSize() const { return model_proto_->pieces_size(); }
-
-std::string ModelInterface::IdToPiece(int id) const {
- return model_proto_->pieces(id).piece();
-}
-
-float ModelInterface::GetScore(int id) const {
- return model_proto_->pieces(id).score();
-}
-
-bool ModelInterface::IsControl(int id) const {
- return (model_proto_->pieces(id).type() ==
- ModelProto::SentencePiece::CONTROL);
-}
-
-bool ModelInterface::IsUnknown(int id) const {
- return (model_proto_->pieces(id).type() ==
- ModelProto::SentencePiece::UNKNOWN);
-}
-
-bool ModelInterface::IsUnused(int id) const {
- return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNUSED);
-}
-
-bool ModelInterface::IsUserDefined(int id) const {
- return (model_proto_->pieces(id).type() ==
- ModelProto::SentencePiece::USER_DEFINED);
-}
-
void ModelInterface::InitializePieces(bool use_prefix_matcher) {
pieces_.clear();
reserved_id_map_.clear();
diff --git a/src/model_interface.h b/src/model_interface.h
index a7a6350..2e3f670 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -23,6 +23,7 @@
#include <vector>
#include "common.h"
+#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"
#include "third_party/darts_clone/darts.h"
@@ -99,38 +100,79 @@ class ModelInterface {
return EncodeResult();
}
- // Returns the size of sentence pieces, which is the same
- // as the size of vocabulary for NMT.
- virtual int GetPieceSize() const;
-
// Returns the vocab id of `piece`.
// Returns UNK(0) if `piece` is unknown
virtual int PieceToId(absl::string_view piece) const;
// Returns the string representation of vocab with `id`.
// id must be 0 <= id < GetPieceSize().
- virtual std::string IdToPiece(int id) const;
+ virtual const std::string &IdToPiece(int id) const {
+ return model_proto_->pieces(id).piece();
+ }
+
+ // Returns the size of sentence pieces, which is the same
+ // as the size of vocabulary for NMT.
+ virtual int GetPieceSize() const { return model_proto_->pieces_size(); }
// Returns the score of `id`.
// Score represents a log probability of the piece.
// We can roughly estimate the unigram frequency of the piece.
- virtual float GetScore(int id) const;
+ virtual float GetScore(int id) const {
+ return model_proto_->pieces(id).score();
+ }
// Returns true if `id` is unknown symbol.
- virtual bool IsUnknown(int id) const;
+ virtual bool IsUnknown(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::UNKNOWN);
+ }
// Returns true if `id` is control symbol.
- virtual bool IsControl(int id) const;
+ virtual bool IsControl(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::CONTROL);
+ }
// Returns true if `id` is unused symbol.
- virtual bool IsUnused(int id) const;
+ virtual bool IsUnused(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::UNUSED);
+ }
// Returns true if `id` is user defined symbol.
- virtual bool IsUserDefined(int id) const;
+ virtual bool IsUserDefined(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::USER_DEFINED);
+ }
protected:
void InitializePieces(bool use_prefix_matcher);
+ // Non-virtual (inlined) implementation for faster execution.
+ inline float GetScoreInlined(int id) const {
+ return model_proto_->pieces(id).score();
+ }
+
+ inline bool IsUnknownInlined(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::UNKNOWN);
+ }
+
+ inline bool IsControlInlined(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::CONTROL);
+ }
+
+ inline bool IsUnusedInlined(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::UNUSED);
+ }
+
+ inline bool IsUserDefinedInlined(int id) const {
+ return (model_proto_->pieces(id).type() ==
+ ModelProto::SentencePiece::USER_DEFINED);
+ }
+
const ModelProto *model_proto_ = nullptr;
// PrefixMatcher for user defined symbols.
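
The protected *Inlined duplicates added above follow a common devirtualization pattern: the public virtual accessors stay overridable (the mock models in the tests rely on that), while hot loops inside the concrete models call the non-virtual copies, which the compiler can inline without a vtable lookup. A stripped-down illustration of the idea (hypothetical class, not from this commit):

#include <vector>

class PieceTable {
 public:
  enum Type { kNormal, kUnused };

  virtual ~PieceTable() = default;

  // Overridable accessor used by external callers and test mocks.
  virtual bool IsUnused(int id) const { return types_[id] == kUnused; }

 protected:
  // Non-virtual twin for inner loops: no virtual dispatch, easy to inline.
  bool IsUnusedInlined(int id) const { return types_[id] == kUnused; }

  std::vector<Type> types_;
};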
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 9bd41a0..f2f8967 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -521,8 +521,9 @@ int SentencePieceProcessor::PieceToId(util::min_string_view piece) const {
return model_->PieceToId(string_util::ToSV(piece));
}
-std::string SentencePieceProcessor::IdToPiece(int id) const {
- CHECK_STATUS_OR_RETURN_DEFAULT("");
+const std::string &SentencePieceProcessor::IdToPiece(int id) const {
+ static const std::string *kEmptyString = new std::string;
+ CHECK_STATUS_OR_RETURN_DEFAULT(*kEmptyString);
return model_->IdToPiece(id);
}
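
Because IdToPiece() now returns a reference, the error branch of CHECK_STATUS_OR_RETURN_DEFAULT also has to hand back an object that outlives the call; the heap-allocated function-local static above provides one without registering a destructor at exit. The same pattern in isolation (illustrative sketch, not from the commit):

#include <string>

const std::string &DefaultPiece() {
  // Leaked on purpose: the object is never destroyed, so references returned
  // from error paths stay valid even during static destruction at shutdown.
  static const std::string *kEmptyString = new std::string;
  return *kEmptyString;
}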
diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index 1cd6c54..61da691 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h
@@ -369,7 +369,7 @@ class SentencePieceProcessor {
virtual int PieceToId(util::min_string_view piece) const;
// Returns the string representation of vocab with `id`.
- virtual std::string IdToPiece(int id) const;
+ virtual const std::string &IdToPiece(int id) const;
// Returns the score of `id`.
// Usually score is an emission log probability of unigram language model.
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 793826d..c1ef138 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -71,7 +71,7 @@ class MockModel : public ModelInterface {
int PieceToId(absl::string_view piece) const { return 0; }
- std::string IdToPiece(int id) const { return ""; }
+ const std::string &IdToPiece(int id) const { return kEmptyString; }
float GetScore(int id) const { return 0.0; }
@@ -79,6 +79,7 @@ class MockModel : public ModelInterface {
absl::string_view input_;
EncodeResult output_;
NBestEncodeResult nbest_output_;
+ const std::string kEmptyString;
};
std::vector<std::string> GetSpVec(const EncodeResult &pieces) {
@@ -457,7 +458,7 @@ TEST(SentencepieceProcessorTest, DecodeTest) {
return port::FindWithDefault(kMap, piece, 0);
}
- std::string IdToPiece(int id) const override {
+ const std::string &IdToPiece(int id) const override {
static std::vector<std::string> kMap = {
"<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H"};
return kMap[id];
diff --git a/src/unigram_model.cc b/src/unigram_model.cc
index b7e2d02..280dec7 100644
--- a/src/unigram_model.cc
+++ b/src/unigram_model.cc
@@ -30,6 +30,9 @@ namespace sentencepiece {
namespace unigram {
namespace {
+// Size of nodes pre-allocated in Lattice.
+constexpr size_t kPreallocateLatticeNodeSize = 1024;
+
// Returns log(exp(x) + exp(y)).
// if init_mode is true, returns log(exp(y)) == y.
// log(\sum_i exp(a[i])) can be computed as
@@ -50,8 +53,8 @@ inline float LogSumExp(float x, float y, bool init_mode) {
}
} // namespace
-Lattice::Lattice() {}
-Lattice::~Lattice() { Clear(); }
+Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {}
+Lattice::~Lattice() {}
const std::vector<Lattice::Node *> &Lattice::begin_nodes(int pos) const {
return begin_nodes_[pos];
@@ -77,10 +80,8 @@ Lattice::Node *Lattice::bos_node() const { return end_nodes_[0][0]; }
Lattice::Node *Lattice::eos_node() const { return begin_nodes_[size()][0]; }
Lattice::Node *Lattice::NewNode() {
- Node *node = new Node;
- memset(node, 0, sizeof(*node));
- node->node_id = all_nodes_.size();
- all_nodes_.push_back(node);
+ Node *node = node_allocator_.Allocate();
+ node->node_id = node_allocator_.size() - 1;
return node;
}
@@ -89,14 +90,14 @@ void Lattice::Clear() {
end_nodes_.clear();
sentence_ = absl::string_view("");
surface_.clear();
- port::STLDeleteElements(&all_nodes_);
- all_nodes_.clear();
+ node_allocator_.Free();
}
void Lattice::SetSentence(absl::string_view sentence) {
Clear();
sentence_ = sentence;
+ surface_.reserve(sentence.size() + 1);
while (!sentence.empty()) {
const int mblen = std::min<int>(string_util::OneCharLen(sentence.data()),
@@ -110,9 +111,10 @@ void Lattice::SetSentence(absl::string_view sentence) {
begin_nodes_.resize(len + 1);
end_nodes_.resize(len + 1);
+ constexpr size_t kReservedNodeSize = 16;
for (int i = 0; i <= len; ++i) {
- begin_nodes_[i].reserve(16);
- end_nodes_[i].reserve(16);
+ begin_nodes_[i].reserve(kReservedNodeSize);
+ end_nodes_[i].reserve(kReservedNodeSize);
}
Node *bos = NewNode();
@@ -183,8 +185,8 @@ float Lattice::PopulateMarginal(float freq,
// alpha and beta (accumulative log prob) in Forward Backward.
// the index of alpha/beta is Node::node_id.
- std::vector<float> alpha(all_nodes_.size(), 0.0);
- std::vector<float> beta(all_nodes_.size(), 0.0);
+ std::vector<float> alpha(node_allocator_.size(), 0.0);
+ std::vector<float> beta(node_allocator_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
@@ -256,19 +258,13 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
using Agenda = std::priority_queue<Hypothesis *, std::vector<Hypothesis *>,
HypothesisComparator>;
+ constexpr size_t kPreallocatedHypothesisSize = 512;
+ model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize);
Agenda agenda;
- std::vector<Hypothesis *> allocated;
std::vector<std::vector<Node *>> results;
- auto NewHypothesis = [&allocated]() {
- Hypothesis *h = new Hypothesis;
- memset(h, 0, sizeof(*h));
- allocated.push_back(h);
- return h;
- };
-
- auto *eos = NewHypothesis();
+ auto *eos = hypothesis_allocator.Allocate();
eos->node = eos_node();
eos->next = nullptr;
eos->fx = eos->node->score;
@@ -297,7 +293,7 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
// Expands new node ending at node->pos
for (Node *lnode : end_nodes(node->pos)) {
- auto *hyp = NewHypothesis();
+ auto *hyp = hypothesis_allocator.Allocate();
hyp->node = lnode;
hyp->gx = lnode->score + top->gx; // just adds node->score
hyp->fx =
@@ -324,7 +320,6 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
}
}
- port::STLDeleteElements(&allocated);
return results;
}
@@ -332,7 +327,7 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
const int len = size();
if (len == 0) return {};
- std::vector<float> alpha(all_nodes_.size(), 0.0);
+ std::vector<float> alpha(node_allocator_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
@@ -368,18 +363,14 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
return results;
}
-ModelBase::ModelBase() {}
-ModelBase::~ModelBase() {}
+// Model::Model() {}
+// Model::~Model() {}
-void ModelBase::PopulateNodes(Lattice *lattice) const {
- auto GetCharsLength = [](const char *begin, int len) {
- const char *end = begin + len;
- int result = 0;
- while (begin < end) {
- begin += std::min<int>(string_util::OneCharLen(begin), end - begin);
- ++result;
- }
- return result;
+void Model::PopulateNodes(Lattice *lattice) const {
+ auto get_chars_length = [&lattice](int begin_pos, const char *end) {
+ int pos = begin_pos;
+ while (lattice->surface(pos) < end) ++pos;
+ return pos - begin_pos;
};
constexpr float kUnkPenalty = 10.0;
@@ -405,14 +396,15 @@ void ModelBase::PopulateNodes(Lattice *lattice) const {
// Inserts pieces to the lattice.
for (size_t k = 0; k < num_nodes; ++k) {
- const int length = GetCharsLength(begin, trie_results[k].length);
+ const int length =
+ get_chars_length(begin_pos, begin + trie_results[k].length);
const int id = trie_results[k].value;
- if (IsUnused(id)) continue;
+ if (IsUnusedInlined(id)) continue;
Lattice::Node *node = lattice->Insert(begin_pos, length);
node->id = id; // the value of Trie stores vocab_id.
// User defined symbol receives extra bonus to always be selected.
- node->score =
- IsUserDefined(id) ? (length * max_score_ + 1.0) : GetScore(id);
+ node->score = IsUserDefinedInlined(id) ? (length * max_score_ + 1.0)
+ : GetScoreInlined(id);
if (!has_single_node && node->length == 1) {
has_single_node = true;
}
@@ -426,7 +418,7 @@ void ModelBase::PopulateNodes(Lattice *lattice) const {
}
}
-int ModelBase::PieceToId(absl::string_view piece) const {
+int Model::PieceToId(absl::string_view piece) const {
auto it = reserved_id_map_.find(piece);
if (it != reserved_id_map_.end()) {
return it->second;
@@ -436,8 +428,7 @@ int ModelBase::PieceToId(absl::string_view piece) const {
return id == -1 ? unk_id_ : id;
}
-void ModelBase::BuildTrie(
- std::vector<std::pair<absl::string_view, int>> *pieces) {
+void Model::BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces) {
if (!status().ok()) return;
if (pieces->empty()) {
diff --git a/src/unigram_model.h b/src/unigram_model.h
index aee61b6..466a1c2 100644
--- a/src/unigram_model.h
+++ b/src/unigram_model.h
@@ -21,6 +21,7 @@
#include <vector>
#include "common.h"
+#include "freelist.h"
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
#include "third_party/darts_clone/darts.h"
@@ -35,14 +36,14 @@ class Lattice {
virtual ~Lattice();
struct Node {
- absl::string_view piece; // Sentence piece representation.
- uint32 pos; // Unicode position in the sentence.
- uint32 length; // Unicode length, not UT8 byte.
- uint32 node_id; // unique id in the current lattice.
- int id; // vocab id. (maybe -1 for UNK)
- float score; // logprob of this sentencepiece.
- float backtrace_score; // backtrace info used in Viterbi.
- Node *prev; // best previous node on Viterbi path.
+ absl::string_view piece; // Sentence piece representation.
+ uint32 pos; // Unicode position in the sentence.
+ uint32 length; // Unicode length, not UT8 byte.
+ uint32 node_id; // unique id in the current lattice.
+ int id; // vocab id. (maybe -1 for UNK)
+ float score; // logprob of this sentencepiece.
+ float backtrace_score; // backtrace info used in Viterbi.
+ Node *prev; // best previous node on Viterbi path.
std::string DebugString() const;
};
@@ -109,17 +110,22 @@ class Lattice {
std::vector<const char *> surface_;
std::vector<std::vector<Node *>> begin_nodes_;
std::vector<std::vector<Node *>> end_nodes_;
- std::vector<Node *> all_nodes_;
+ model::FreeList<Node> node_allocator_;
};
-// Base class for Unigram Model.
-// We have base Model class because we will have different
-// implementations for training and testing.
-// Trie management part is shared by training and testing.
-class ModelBase : public ModelInterface {
+class Model : public ModelInterface {
public:
- ModelBase();
- ~ModelBase() override;
+ explicit Model(const ModelProto &model_proto);
+ Model() {}
+ ~Model() override;
+
+ EncodeResult Encode(absl::string_view normalized) const override;
+
+ NBestEncodeResult NBestEncode(absl::string_view normalized,
+ int nbest_size) const override;
+
+ EncodeResult SampleEncode(absl::string_view normalized,
+ float theta) const override;
// Returns the minimum score in sentence pieces.
// min_score() - 10 is used for the cost of unknown sentence.
@@ -150,19 +156,6 @@ class ModelBase : public ModelInterface {
int trie_results_size_;
};
-// Unigram model class for decoding.
-class Model : public ModelBase {
- public:
- explicit Model(const ModelProto &model_proto);
- ~Model() override;
-
- EncodeResult Encode(absl::string_view normalized) const override;
-
- NBestEncodeResult NBestEncode(absl::string_view normalized,
- int nbest_size) const override;
-
- EncodeResult SampleEncode(absl::string_view normalized, float theta) const override;
-};
} // namespace unigram
} // namespace sentencepiece
#endif // UNIGRAM_MODEL_H_
diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc
index 0974cb4..04abe17 100644
--- a/src/unigram_model_trainer.cc
+++ b/src/unigram_model_trainer.cc
@@ -81,6 +81,7 @@ class ThreadPool {
TrainerModel::TrainerModel(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {}
+
TrainerModel::~TrainerModel() {}
const TrainerModel::SentencePieces &TrainerModel::GetSentencePieces() const {
@@ -92,13 +93,19 @@ void TrainerModel::SetSentencePieces(SentencePieces &&sentencepieces) {
CHECK(!sentencepieces_.empty());
min_score_ = FLT_MAX;
+ model_proto_data_.Clear();
+ model_proto_ = &model_proto_data_;
std::vector<std::pair<absl::string_view, int>> pieces;
+
for (size_t i = 0; i < sentencepieces_.size(); ++i) {
const absl::string_view w = sentencepieces_[i].first; // piece
- const float score = sentencepieces_[i].second; // score.
+ const float score = sentencepieces_[i].second; // score.
CHECK(!std::isnan(score));
pieces.emplace_back(w, i);
min_score_ = std::min(min_score_, score);
+ auto *piece = model_proto_data_.add_pieces();
+ piece->set_piece(w.data(), w.size());
+ piece->set_score(score);
}
BuildTrie(&pieces);
diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h
index 2bd31d0..fa069e2 100644
--- a/src/unigram_model_trainer.h
+++ b/src/unigram_model_trainer.h
@@ -32,11 +32,12 @@ namespace unigram {
using string_util::UnicodeText;
-class TrainerModel : public ModelBase {
+class TrainerModel : public Model {
public:
using SentencePieces = std::vector<std::pair<std::string, float>>;
- TrainerModel() = delete;
+ TrainerModel() {}
+ TrainerModel(const ModelProto &model_proto) = delete;
TrainerModel(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizaiton_spec);
~TrainerModel() override;
@@ -49,30 +50,15 @@ class TrainerModel : public ModelBase {
// The meta symbols, e.g., </s> are NOT included.
void SetSentencePieces(SentencePieces &&sentencepieces);
- int GetPieceSize() const override { return sentencepieces_.size(); }
-
- float GetScore(int index) const override {
- return sentencepieces_[index].second;
- }
-
- std::string IdToPiece(int id) const override {
- return sentencepieces_[id].first;
+ EncodeResult Encode(absl::string_view normalized) const override {
+ return {};
}
- bool IsControl(int id) const override { return false; }
-
- bool IsUnknown(int id) const override { return false; }
-
- bool IsUnused(int id) const override { return false; }
-
- bool IsUserDefined(int id) const override { return false; }
-
- EncodeResult Encode(absl::string_view normalized) const override { return {}; }
-
private:
SentencePieces sentencepieces_;
TrainerSpec trainer_spec_;
NormalizerSpec normalizer_spec_;
+ ModelProto model_proto_data_;
};
class Trainer : public TrainerInterface {