diff options
Diffstat (limited to 'src/sentencepiece_processor.cc')
-rw-r--r-- | src/sentencepiece_processor.cc | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1f425df..8c9c208 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -556,25 +556,25 @@ bool SentencePieceProcessor::IsUnused(int id) const { } int SentencePieceProcessor::unk_id() const { - const int id = PieceToId(ModelInterface::kUNK()); + const int id = PieceToId(util::min_string_view(model_->unk_piece().data())); if (IsUnknown(id)) return id; return -1; } int SentencePieceProcessor::bos_id() const { - const int id = PieceToId(ModelInterface::kBOS()); + const int id = PieceToId(util::min_string_view(model_->bos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::eos_id() const { - const int id = PieceToId(ModelInterface::kEOS()); + const int id = PieceToId(util::min_string_view(model_->eos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::pad_id() const { - const int id = PieceToId(ModelInterface::kPAD()); + const int id = PieceToId(util::min_string_view(model_->pad_piece().data())); if (IsControl(id)) return id; return -1; } @@ -591,8 +591,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( break; case EOS: { auto *piece = spt->add_pieces(); - piece->set_id(PieceToId("</s>")); - piece->set_piece("</s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->eos_piece().data()))); + piece->set_piece(model_->eos_piece().data(), + model_->eos_piece().size()); } break; case BOS: { auto *array = spt->mutable_pieces(); @@ -601,8 +603,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( array->SwapElements(i - 1, i); } auto *piece = array->Mutable(0); - piece->set_id(PieceToId("<s>")); - piece->set_piece("<s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->bos_piece().data()))); + piece->set_piece(model_->bos_piece().data(), + model_->bos_piece().size()); } break; default: return util::InternalError("unknown extra_option type."); @@ -634,12 +638,14 @@ util::Status SentencePieceProcessor::ParseExtraOptions( extra_options->push_back(it->second); if (it->second == SentencePieceProcessor::BOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>"))) - << "id for `<s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + PieceToId(util::min_string_view(model_->bos_piece().data())))) + << "id for `" << model_->bos_piece() << "` is not defined."; } if (it->second == SentencePieceProcessor::EOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>"))) - << "id for `</s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + PieceToId(util::min_string_view(model_->eos_piece().data())))) + << "id for `" << model_->eos_piece() << "` is not defined."; } } return util::OkStatus(); |