diff options
author | Taku Kudo <taku@google.com> | 2018-04-28 20:50:07 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-04-28 20:50:07 +0300 |
commit | d16531bfb866e2fca246a36316876b934aa427f7 (patch) | |
tree | 0215e1b3555b02363b17d425b3c94200d92cb6fd /src/model_interface.cc | |
parent | baf5d7a2995018ede996173cdf0febcdf23cba2d (diff) |
Uses util::Status to propagate error messages
Diffstat (limited to 'src/model_interface.cc')
-rw-r--r-- | src/model_interface.cc | 40 |
1 files changed, 32 insertions, 8 deletions
diff --git a/src/model_interface.cc b/src/model_interface.cc index d4602ea..059e8bf 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -19,7 +19,7 @@ namespace sentencepiece { ModelInterface::ModelInterface(const ModelProto &model_proto) - : model_proto_(&model_proto) {} + : model_proto_(&model_proto), status_(util::OkStatus()) {} ModelInterface::~ModelInterface() {} int ModelInterface::PieceToId(StringPiece piece) const { @@ -34,28 +34,52 @@ int ModelInterface::PieceToId(StringPiece piece) const { return unk_id_; } -int ModelInterface::GetPieceSize() const { - return CHECK_NOTNULL(model_proto_)->pieces_size(); -} +int ModelInterface::GetPieceSize() const { return model_proto_->pieces_size(); } std::string ModelInterface::IdToPiece(int id) const { - return CHECK_NOTNULL(model_proto_)->pieces(id).piece(); + return model_proto_->pieces(id).piece(); } float ModelInterface::GetScore(int id) const { - return CHECK_NOTNULL(model_proto_)->pieces(id).score(); + return model_proto_->pieces(id).score(); } bool ModelInterface::IsControl(int id) const { - return (CHECK_NOTNULL(model_proto_)->pieces(id).type() == + return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::CONTROL); } bool ModelInterface::IsUnknown(int id) const { - return (CHECK_NOTNULL(model_proto_)->pieces(id).type() == + return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNKNOWN); } +void ModelInterface::InitializePieces(bool enable_user_defined) { + pieces_.clear(); + reserved_id_map_.clear(); + unk_id_ = 0; + + for (int i = 0; i < model_proto_->pieces_size(); ++i) { + const auto &sp = model_proto_->pieces(i); + if (!enable_user_defined && + sp.type() == ModelProto::SentencePiece::USER_DEFINED) { + status_ = util::InternalError("User defined symbol is not supported."); + return; + } + + const bool is_normal_piece = + (sp.type() == ModelProto::SentencePiece::NORMAL || + sp.type() == ModelProto::SentencePiece::USER_DEFINED); + if (!port::InsertIfNotPresent( + is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) { + status_ = util::InternalError(sp.piece() + " is already defined."); + return; + } + + if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i; + } +} + std::vector<StringPiece> SplitIntoWords(StringPiece text) { const char *begin = text.data(); const char *end = text.data() + text.size(); |