
github.com/marian-nmt/sentencepiece.git
author     Taku Kudo <taku@google.com>  2018-04-28 20:50:07 +0300
committer  Taku Kudo <taku@google.com>  2018-04-28 20:50:07 +0300
commit     d16531bfb866e2fca246a36316876b934aa427f7 (patch)
tree       0215e1b3555b02363b17d425b3c94200d92cb6fd /src/unigram_model.cc
parent     baf5d7a2995018ede996173cdf0febcdf23cba2d (diff)
Uses util::Status to propagate error messages
Diffstat (limited to 'src/unigram_model.cc')
-rw-r--r--  src/unigram_model.cc  85
1 file changed, 39 insertions(+), 46 deletions(-)
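The commit replaces process-aborting CHECK* macros with a status_ member that records the first error and is inspected by callers. A minimal, self-contained sketch of the pattern follows; the Status stand-in below is illustrative only, while the real code uses util::Status and util::InternalError from this repository.

#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for util::Status: default-constructed means OK,
// otherwise it carries an error message.
struct Status {
  std::string message;  // empty means OK
  bool ok() const { return message.empty(); }
};

Status InternalError(std::string msg) { return Status{std::move(msg)}; }

class ModelBase {
 public:
  // Callers inspect this instead of relying on CHECK() aborting the process.
  Status status() const { return status_; }

 protected:
  // Sketch of the pattern used in BuildTrie(): bail out early if an earlier
  // step already failed, and record new failures instead of crashing.
  void BuildTrie(const std::vector<std::pair<std::string, int>> *pieces) {
    if (!status().ok()) return;  // propagate a previous error
    if (pieces->empty()) {
      status_ = InternalError("No pieces are loaded.");
      return;
    }
    // ... build the double-array trie; later failures also set status_ ...
  }

  Status status_;  // OK until the first error is recorded
};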
diff --git a/src/unigram_model.cc b/src/unigram_model.cc
index 075a9dc..dce7c59 100644
--- a/src/unigram_model.cc
+++ b/src/unigram_model.cc
@@ -96,7 +96,6 @@ void Lattice::SetSentence(StringPiece sentence) {
Clear();
sentence_ = sentence;
- CHECK(!sentence_.empty());
const char *begin = sentence_.data();
const char *end = sentence_.data() + sentence_.size();
@@ -143,7 +142,6 @@ Lattice::Node *Lattice::Insert(int pos, int length) {
std::vector<Lattice::Node *> Lattice::Viterbi() {
const int len = size();
- CHECK_GT(len, 0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
@@ -177,10 +175,9 @@ std::vector<Lattice::Node *> Lattice::Viterbi() {
float Lattice::PopulateMarginal(float freq,
std::vector<float> *expected) const {
- CHECK_NOTNULL(expected);
+ if (expected == nullptr) return 0.0;
const int len = size();
- CHECK_GT(len, 0);
// alpha and beta (accumulative log prob) in Forward Backward.
// the index of alpha/beta is Node::node_id.
@@ -222,8 +219,10 @@ float Lattice::PopulateMarginal(float freq,
}
std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
- CHECK_GT(size(), 0);
- CHECK_GE(nbest_size, 1);
+ if (nbest_size < 1) {
+ LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
+ return {};
+ }
if (nbest_size == 1) {
return {Viterbi()};
@@ -328,13 +327,10 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
}
std::vector<Lattice::Node *> Lattice::Sample(float theta) {
- std::vector<float> alpha(all_nodes_.size(), 0.0);
- CHECK_GE(theta, 0.0);
-
const int len = size();
- CHECK_GT(len, 0);
- CHECK_GT(begin_nodes_.size(), 0);
- CHECK_GT(end_nodes_.size(), 0);
+ if (len == 0) return {};
+
+ std::vector<float> alpha(all_nodes_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
@@ -374,9 +370,6 @@ ModelBase::ModelBase() {}
ModelBase::~ModelBase() {}
void ModelBase::PopulateNodes(Lattice *lattice) const {
- CHECK_NOTNULL(lattice);
- CHECK_NOTNULL(trie_);
-
auto GetCharsLength = [](const char *begin, int len) {
const char *end = begin + len;
int result = 0;
@@ -393,9 +386,6 @@ void ModelBase::PopulateNodes(Lattice *lattice) const {
const int len = lattice->size();
const char *end = lattice->sentence() + lattice->utf8_size();
- // Initializes the buffer for return values.
- CHECK_GT(trie_results_size_, 0);
-
// +1 just in case.
std::vector<Darts::DoubleArray::result_pair_type> trie_results(
trie_results_size_ + 1);
@@ -440,9 +430,13 @@ int ModelBase::PieceToId(StringPiece piece) const {
return id == -1 ? unk_id_ : id;
}
-void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) {
- CHECK_NOTNULL(pieces);
- CHECK(!pieces->empty());
+void ModelBase::BuildTrie(std::vector<std::pair<StringPiece, int>> *pieces) {
+ if (!status().ok()) return;
+
+ if (pieces->empty()) {
+ status_ = util::InternalError("No pieces are loaded.");
+ return;
+ }
// sort by sentencepiece since DoubleArray::build()
// only accepts sorted strings.
@@ -452,14 +446,16 @@ void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) {
std::vector<const char *> key(pieces->size());
std::vector<int> value(pieces->size());
for (size_t i = 0; i < pieces->size(); ++i) {
- key[i] = (*pieces)[i].first.c_str(); // sorted piece.
- value[i] = (*pieces)[i].second; // vocab_id
+ key[i] = (*pieces)[i].first.data(); // sorted piece.
+ value[i] = (*pieces)[i].second; // vocab_id
}
trie_ = port::MakeUnique<Darts::DoubleArray>();
- CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
- &value[0]))
- << "cannot build double-array";
+ if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
+ &value[0]) != 0) {
+ status_ = util::InternalError("Cannot build double-array.");
+ return;
+ }
// Computes the maximum number of shared prefixes in the trie.
const int kMaxTrieResultsSize = 1024;
@@ -471,37 +467,35 @@ void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) {
p.first.data(), results.data(), results.size(), p.first.size());
trie_results_size_ = std::max(trie_results_size_, num_nodes);
}
- CHECK_GT(trie_results_size_, 0);
+
+ pieces_.clear();
+
+ if (trie_results_size_ == 0)
+ status_ = util::InternalError("No entry is found in the trie.");
}
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
- min_score_ = FLT_MAX;
- std::vector<std::pair<std::string, int>> pieces; // <piece, vocab_id>
- for (int i = 0; i < model_proto_->pieces_size(); ++i) {
- const auto &sp = model_proto_->pieces(i);
- CHECK(!sp.piece().empty());
+ InitializePieces(true /* use_user_defined */);
+
+ min_score_ = FLT_MAX;
+ for (const auto &sp : model_proto_->pieces()) {
if (sp.type() == ModelProto::SentencePiece::NORMAL ||
- sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
- CHECK(sp.has_score());
- pieces.emplace_back(sp.piece(), i);
- } else {
- port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
- if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
- }
- if (sp.type() == ModelProto::SentencePiece::NORMAL) {
+ sp.type() == ModelProto::SentencePiece::USER_DEFINED)
min_score_ = std::min(min_score_, sp.score());
- }
}
+ std::vector<std::pair<StringPiece, int>> pieces;
+ for (const auto &it : pieces_) pieces.emplace_back(it.first, it.second);
+
BuildTrie(&pieces);
}
Model::~Model() {}
EncodeResult Model::Encode(StringPiece normalized) const {
- if (normalized.empty()) {
+ if (!status().ok() || normalized.empty()) {
return {};
}
@@ -519,12 +513,11 @@ EncodeResult Model::Encode(StringPiece normalized) const {
NBestEncodeResult Model::NBestEncode(StringPiece normalized,
int nbest_size) const {
- if (normalized.empty()) {
+ if (!status().ok() || normalized.empty()) {
return {{{}, 0.0}};
}
- CHECK_GE(nbest_size, 1);
- CHECK_LT(nbest_size, 1024);
+ nbest_size = std::max<int>(1, std::min<int>(nbest_size, 1024));
Lattice lattice;
lattice.SetSentence(normalized);
@@ -545,7 +538,7 @@ NBestEncodeResult Model::NBestEncode(StringPiece normalized,
}
EncodeResult Model::SampleEncode(StringPiece normalized, float theta) const {
- if (normalized.empty()) {
+ if (!status().ok() || normalized.empty()) {
return {};
}
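With this change, the public entry points (Encode, NBestEncode, SampleEncode) consult status() and return an empty result on a half-initialized model instead of aborting. A hypothetical caller-side check is sketched below; the unigram namespace and variable names are assumptions for illustration, while status() and Encode() are the methods touched by this diff.

// Hypothetical usage: verify that the model loaded cleanly before encoding,
// since initialization errors are now reported through status().
sentencepiece::unigram::Model model(model_proto);
if (!model.status().ok()) {
  std::cerr << "failed to initialize the unigram model" << std::endl;
  return 1;
}
const EncodeResult result = model.Encode(normalized_input);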