diff options
author | Taku Kudo <taku@google.com> | 2018-06-04 14:32:37 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-06-04 14:32:37 +0300 |
commit | 4e3bcf1373fb7c8ddca151ddf2a6f10914057cfa (patch) | |
tree | 71ea36b440216b25cfa7290be2a09d3ac82b7acb /src/model_interface.cc | |
parent | 4f7af0dfadbf547264296d46924055842c901b60 (diff) |
Updated normalizer
Diffstat (limited to 'src/model_interface.cc')
-rw-r--r-- | src/model_interface.cc | 17 |
1 files changed, 14 insertions, 3 deletions
```diff
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 62ecf17..b76d84b 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -57,14 +57,14 @@ bool ModelInterface::IsUnknown(int id) const {
 void ModelInterface::InitializePieces(bool enable_user_defined) {
   pieces_.clear();
   reserved_id_map_.clear();
-  unk_id_ = 0;
+  unk_id_ = -1;
 
   for (int i = 0; i < model_proto_->pieces_size(); ++i) {
     const auto &sp = model_proto_->pieces(i);
     if (!enable_user_defined &&
         sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
       status_ = util::StatusBuilder(util::error::INTERNAL)
-                << "user defined symbol is not supported.";
+                << "User defined symbol is not supported.";
       return;
     }
 
@@ -78,8 +78,19 @@ void ModelInterface::InitializePieces(bool enable_user_defined) {
       return;
     }
 
-    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
+    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
+      if (unk_id_ >= 0) {
+        status_ = util::StatusBuilder(util::error::INTERNAL)
+                  << "unk is already defined.";
+        return;
+      }
+      unk_id_ = i;
+    }
   }
+
+  if (unk_id_ == -1)
+    status_ = util::StatusBuilder(util::error::INTERNAL)
+              << "unk is not defined.";
 }
 
 std::vector<StringPiece> SplitIntoWords(StringPiece text) {
```