Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2018-04-28 20:50:07 +0300
committer Taku Kudo <taku@google.com> 2018-04-28 20:50:07 +0300
commit d16531bfb866e2fca246a36316876b934aa427f7 (patch)
tree 0215e1b3555b02363b17d425b3c94200d92cb6fd /src/model_interface.cc
parent baf5d7a2995018ede996173cdf0febcdf23cba2d (diff)
Uses util::Status to propagate error messages
Diffstat (limited to 'src/model_interface.cc')
-rw-r--r-- src/model_interface.cc 40
1 files changed, 32 insertions, 8 deletions
diff --git a/src/model_interface.cc b/src/model_interface.cc
index d4602ea..059e8bf 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -19,7 +19,7 @@
namespace sentencepiece {
ModelInterface::ModelInterface(const ModelProto &model_proto)
- : model_proto_(&model_proto) {}
+ : model_proto_(&model_proto), status_(util::OkStatus()) {}
ModelInterface::~ModelInterface() {}
int ModelInterface::PieceToId(StringPiece piece) const {
@@ -34,28 +34,52 @@ int ModelInterface::PieceToId(StringPiece piece) const {
return unk_id_;
}
-int ModelInterface::GetPieceSize() const {
- return CHECK_NOTNULL(model_proto_)->pieces_size();
-}
+int ModelInterface::GetPieceSize() const { return model_proto_->pieces_size(); }
std::string ModelInterface::IdToPiece(int id) const {
- return CHECK_NOTNULL(model_proto_)->pieces(id).piece();
+ return model_proto_->pieces(id).piece();
}
float ModelInterface::GetScore(int id) const {
- return CHECK_NOTNULL(model_proto_)->pieces(id).score();
+ return model_proto_->pieces(id).score();
}
bool ModelInterface::IsControl(int id) const {
- return (CHECK_NOTNULL(model_proto_)->pieces(id).type() ==
+ return (model_proto_->pieces(id).type() ==
ModelProto::SentencePiece::CONTROL);
}
bool ModelInterface::IsUnknown(int id) const {
- return (CHECK_NOTNULL(model_proto_)->pieces(id).type() ==
+ return (model_proto_->pieces(id).type() ==
ModelProto::SentencePiece::UNKNOWN);
}
+void ModelInterface::InitializePieces(bool enable_user_defined) {
+ pieces_.clear();
+ reserved_id_map_.clear();
+ unk_id_ = 0;
+
+ for (int i = 0; i < model_proto_->pieces_size(); ++i) {
+ const auto &sp = model_proto_->pieces(i);
+ if (!enable_user_defined &&
+ sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
+ status_ = util::InternalError("User defined symbol is not supported.");
+ return;
+ }
+
+ const bool is_normal_piece =
+ (sp.type() == ModelProto::SentencePiece::NORMAL ||
+ sp.type() == ModelProto::SentencePiece::USER_DEFINED);
+ if (!port::InsertIfNotPresent(
+ is_normal_piece ? &pieces_ : &reserved_id_map_, sp.piece(), i)) {
+ status_ = util::InternalError(sp.piece() + " is already defined.");
+ return;
+ }
+
+ if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
+ }
+}
+
std::vector<StringPiece> SplitIntoWords(StringPiece text) {
const char *begin = text.data();
const char *end = text.data() + text.size();