From d16531bfb866e2fca246a36316876b934aa427f7 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sun, 29 Apr 2018 02:50:07 +0900 Subject: Uses util::Status to propagate error messages --- src/unigram_model.cc | 85 ++++++++++++++++++++++++---------------------------- 1 file changed, 39 insertions(+), 46 deletions(-) (limited to 'src/unigram_model.cc') diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 075a9dc..dce7c59 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -96,7 +96,6 @@ void Lattice::SetSentence(StringPiece sentence) { Clear(); sentence_ = sentence; - CHECK(!sentence_.empty()); const char *begin = sentence_.data(); const char *end = sentence_.data() + sentence_.size(); @@ -143,7 +142,6 @@ Lattice::Node *Lattice::Insert(int pos, int length) { std::vector Lattice::Viterbi() { const int len = size(); - CHECK_GT(len, 0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -177,10 +175,9 @@ std::vector Lattice::Viterbi() { float Lattice::PopulateMarginal(float freq, std::vector *expected) const { - CHECK_NOTNULL(expected); + if (expected == nullptr) return 0.0; const int len = size(); - CHECK_GT(len, 0); // alpha and beta (accumulative log prob) in Forward Backward. // the index of alpha/beta is Node::node_id. @@ -222,8 +219,10 @@ float Lattice::PopulateMarginal(float freq, } std::vector> Lattice::NBest(size_t nbest_size) { - CHECK_GT(size(), 0); - CHECK_GE(nbest_size, 1); + if (nbest_size < 1) { + LOG(WARNING) << "nbest_size >= 1. Returns empty result."; + return {}; + } if (nbest_size == 1) { return {Viterbi()}; @@ -328,13 +327,10 @@ std::vector> Lattice::NBest(size_t nbest_size) { } std::vector Lattice::Sample(float theta) { - std::vector alpha(all_nodes_.size(), 0.0); - CHECK_GE(theta, 0.0); - const int len = size(); - CHECK_GT(len, 0); - CHECK_GT(begin_nodes_.size(), 0); - CHECK_GT(end_nodes_.size(), 0); + if (len == 0) return {}; + + std::vector alpha(all_nodes_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -374,9 +370,6 @@ ModelBase::ModelBase() {} ModelBase::~ModelBase() {} void ModelBase::PopulateNodes(Lattice *lattice) const { - CHECK_NOTNULL(lattice); - CHECK_NOTNULL(trie_); - auto GetCharsLength = [](const char *begin, int len) { const char *end = begin + len; int result = 0; @@ -393,9 +386,6 @@ void ModelBase::PopulateNodes(Lattice *lattice) const { const int len = lattice->size(); const char *end = lattice->sentence() + lattice->utf8_size(); - // Initializes the buffer for return values. - CHECK_GT(trie_results_size_, 0); - // +1 just in case. std::vector trie_results( trie_results_size_ + 1); @@ -440,9 +430,13 @@ int ModelBase::PieceToId(StringPiece piece) const { return id == -1 ? unk_id_ : id; } -void ModelBase::BuildTrie(std::vector> *pieces) { - CHECK_NOTNULL(pieces); - CHECK(!pieces->empty()); +void ModelBase::BuildTrie(std::vector> *pieces) { + if (!status().ok()) return; + + if (pieces->empty()) { + status_ = util::InternalError("No pieces are loaded."); + return; + } // sort by sentencepiece since DoubleArray::build() // only accepts sorted strings. @@ -452,14 +446,16 @@ void ModelBase::BuildTrie(std::vector> *pieces) { std::vector key(pieces->size()); std::vector value(pieces->size()); for (size_t i = 0; i < pieces->size(); ++i) { - key[i] = (*pieces)[i].first.c_str(); // sorted piece. - value[i] = (*pieces)[i].second; // vocab_id + key[i] = (*pieces)[i].first.data(); // sorted piece. + value[i] = (*pieces)[i].second; // vocab_id } trie_ = port::MakeUnique(); - CHECK_EQ(0, trie_->build(key.size(), const_cast(&key[0]), nullptr, - &value[0])) - << "cannot build double-array"; + if (trie_->build(key.size(), const_cast(&key[0]), nullptr, + &value[0]) != 0) { + status_ = util::InternalError("Cannot build double-array."); + return; + } // Computes the maximum number of shared prefixes in the trie. const int kMaxTrieResultsSize = 1024; @@ -471,37 +467,35 @@ void ModelBase::BuildTrie(std::vector> *pieces) { p.first.data(), results.data(), results.size(), p.first.size()); trie_results_size_ = std::max(trie_results_size_, num_nodes); } - CHECK_GT(trie_results_size_, 0); + + pieces_.clear(); + + if (trie_results_size_ == 0) + status_ = util::InternalError("No entry is found in the trie."); } Model::Model(const ModelProto &model_proto) { model_proto_ = &model_proto; - min_score_ = FLT_MAX; - std::vector> pieces; // - for (int i = 0; i < model_proto_->pieces_size(); ++i) { - const auto &sp = model_proto_->pieces(i); - CHECK(!sp.piece().empty()); + InitializePieces(true /* use_user_defined */); + + min_score_ = FLT_MAX; + for (const auto &sp : model_proto_->pieces()) { if (sp.type() == ModelProto::SentencePiece::NORMAL || - sp.type() == ModelProto::SentencePiece::USER_DEFINED) { - CHECK(sp.has_score()); - pieces.emplace_back(sp.piece(), i); - } else { - port::InsertOrDie(&reserved_id_map_, sp.piece(), i); - if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i; - } - if (sp.type() == ModelProto::SentencePiece::NORMAL) { + sp.type() == ModelProto::SentencePiece::USER_DEFINED) min_score_ = std::min(min_score_, sp.score()); - } } + std::vector> pieces; + for (const auto &it : pieces_) pieces.emplace_back(it.first, it.second); + BuildTrie(&pieces); } Model::~Model() {} EncodeResult Model::Encode(StringPiece normalized) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {}; } @@ -519,12 +513,11 @@ EncodeResult Model::Encode(StringPiece normalized) const { NBestEncodeResult Model::NBestEncode(StringPiece normalized, int nbest_size) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {{{}, 0.0}}; } - CHECK_GE(nbest_size, 1); - CHECK_LT(nbest_size, 1024); + nbest_size = std::max(1, std::min(nbest_size, 1024)); Lattice lattice; lattice.SetSentence(normalized); @@ -545,7 +538,7 @@ NBestEncodeResult Model::NBestEncode(StringPiece normalized, } EncodeResult Model::SampleEncode(StringPiece normalized, float theta) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {}; } -- cgit v1.2.3