diff options
author | Taku Kudo <taku@google.com> | 2018-04-28 20:50:07 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-04-28 20:50:07 +0300 |
commit | d16531bfb866e2fca246a36316876b934aa427f7 (patch) | |
tree | 0215e1b3555b02363b17d425b3c94200d92cb6fd /src/unigram_model.cc | |
parent | baf5d7a2995018ede996173cdf0febcdf23cba2d (diff) |
Uses util::Status to propagate error messages
Diffstat (limited to 'src/unigram_model.cc')
-rw-r--r-- | src/unigram_model.cc | 85 |
1 files changed, 39 insertions, 46 deletions
diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 075a9dc..dce7c59 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -96,7 +96,6 @@ void Lattice::SetSentence(StringPiece sentence) { Clear(); sentence_ = sentence; - CHECK(!sentence_.empty()); const char *begin = sentence_.data(); const char *end = sentence_.data() + sentence_.size(); @@ -143,7 +142,6 @@ Lattice::Node *Lattice::Insert(int pos, int length) { std::vector<Lattice::Node *> Lattice::Viterbi() { const int len = size(); - CHECK_GT(len, 0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -177,10 +175,9 @@ std::vector<Lattice::Node *> Lattice::Viterbi() { float Lattice::PopulateMarginal(float freq, std::vector<float> *expected) const { - CHECK_NOTNULL(expected); + if (expected == nullptr) return 0.0; const int len = size(); - CHECK_GT(len, 0); // alpha and beta (accumulative log prob) in Forward Backward. // the index of alpha/beta is Node::node_id. @@ -222,8 +219,10 @@ float Lattice::PopulateMarginal(float freq, } std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) { - CHECK_GT(size(), 0); - CHECK_GE(nbest_size, 1); + if (nbest_size < 1) { + LOG(WARNING) << "nbest_size >= 1. Returns empty result."; + return {}; + } if (nbest_size == 1) { return {Viterbi()}; @@ -328,13 +327,10 @@ std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) { } std::vector<Lattice::Node *> Lattice::Sample(float theta) { - std::vector<float> alpha(all_nodes_.size(), 0.0); - CHECK_GE(theta, 0.0); - const int len = size(); - CHECK_GT(len, 0); - CHECK_GT(begin_nodes_.size(), 0); - CHECK_GT(end_nodes_.size(), 0); + if (len == 0) return {}; + + std::vector<float> alpha(all_nodes_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { @@ -374,9 +370,6 @@ ModelBase::ModelBase() {} ModelBase::~ModelBase() {} void ModelBase::PopulateNodes(Lattice *lattice) const { - CHECK_NOTNULL(lattice); - CHECK_NOTNULL(trie_); - auto GetCharsLength = [](const char *begin, int len) { const char *end = begin + len; int result = 0; @@ -393,9 +386,6 @@ void ModelBase::PopulateNodes(Lattice *lattice) const { const int len = lattice->size(); const char *end = lattice->sentence() + lattice->utf8_size(); - // Initializes the buffer for return values. - CHECK_GT(trie_results_size_, 0); - // +1 just in case. std::vector<Darts::DoubleArray::result_pair_type> trie_results( trie_results_size_ + 1); @@ -440,9 +430,13 @@ int ModelBase::PieceToId(StringPiece piece) const { return id == -1 ? unk_id_ : id; } -void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) { - CHECK_NOTNULL(pieces); - CHECK(!pieces->empty()); +void ModelBase::BuildTrie(std::vector<std::pair<StringPiece, int>> *pieces) { + if (!status().ok()) return; + + if (pieces->empty()) { + status_ = util::InternalError("No pieces are loaded."); + return; + } // sort by sentencepiece since DoubleArray::build() // only accepts sorted strings. @@ -452,14 +446,16 @@ void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) { std::vector<const char *> key(pieces->size()); std::vector<int> value(pieces->size()); for (size_t i = 0; i < pieces->size(); ++i) { - key[i] = (*pieces)[i].first.c_str(); // sorted piece. - value[i] = (*pieces)[i].second; // vocab_id + key[i] = (*pieces)[i].first.data(); // sorted piece. + value[i] = (*pieces)[i].second; // vocab_id } trie_ = port::MakeUnique<Darts::DoubleArray>(); - CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr, - &value[0])) - << "cannot build double-array"; + if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr, + &value[0]) != 0) { + status_ = util::InternalError("Cannot build double-array."); + return; + } // Computes the maximum number of shared prefixes in the trie. const int kMaxTrieResultsSize = 1024; @@ -471,37 +467,35 @@ void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) { p.first.data(), results.data(), results.size(), p.first.size()); trie_results_size_ = std::max(trie_results_size_, num_nodes); } - CHECK_GT(trie_results_size_, 0); + + pieces_.clear(); + + if (trie_results_size_ == 0) + status_ = util::InternalError("No entry is found in the trie."); } Model::Model(const ModelProto &model_proto) { model_proto_ = &model_proto; - min_score_ = FLT_MAX; - std::vector<std::pair<std::string, int>> pieces; // <piece, vocab_id> - for (int i = 0; i < model_proto_->pieces_size(); ++i) { - const auto &sp = model_proto_->pieces(i); - CHECK(!sp.piece().empty()); + InitializePieces(true /* use_user_defined */); + + min_score_ = FLT_MAX; + for (const auto &sp : model_proto_->pieces()) { if (sp.type() == ModelProto::SentencePiece::NORMAL || - sp.type() == ModelProto::SentencePiece::USER_DEFINED) { - CHECK(sp.has_score()); - pieces.emplace_back(sp.piece(), i); - } else { - port::InsertOrDie(&reserved_id_map_, sp.piece(), i); - if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i; - } - if (sp.type() == ModelProto::SentencePiece::NORMAL) { + sp.type() == ModelProto::SentencePiece::USER_DEFINED) min_score_ = std::min(min_score_, sp.score()); - } } + std::vector<std::pair<StringPiece, int>> pieces; + for (const auto &it : pieces_) pieces.emplace_back(it.first, it.second); + BuildTrie(&pieces); } Model::~Model() {} EncodeResult Model::Encode(StringPiece normalized) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {}; } @@ -519,12 +513,11 @@ EncodeResult Model::Encode(StringPiece normalized) const { NBestEncodeResult Model::NBestEncode(StringPiece normalized, int nbest_size) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {{{}, 0.0}}; } - CHECK_GE(nbest_size, 1); - CHECK_LT(nbest_size, 1024); + nbest_size = std::max<int>(1, std::min<int>(nbest_size, 1024)); Lattice lattice; lattice.SetSentence(normalized); @@ -545,7 +538,7 @@ NBestEncodeResult Model::NBestEncode(StringPiece normalized, } EncodeResult Model::SampleEncode(StringPiece normalized, float theta) const { - if (normalized.empty()) { + if (!status().ok() || normalized.empty()) { return {}; } |