diff options
author | Taku Kudo <taku@google.com> | 2018-12-08 16:08:05 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2018-12-08 16:08:05 +0300 |
commit | 5f635d0892debe2001ea889f0cf02185449fcaec (patch) | |
tree | e84d1651a56341eef9cd89c9bdb46a878a952de4 /src/sentencepiece_processor.cc | |
parent | f6229bfad55f28e6b35e6860d9c06a5ea9bd83c9 (diff) |
Add support for changing the unk/bos/eos/pad pieces
Diffstat (limited to 'src/sentencepiece_processor.cc')
-rw-r--r-- | src/sentencepiece_processor.cc | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1f425df..8c9c208 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -556,25 +556,25 @@ bool SentencePieceProcessor::IsUnused(int id) const { } int SentencePieceProcessor::unk_id() const { - const int id = PieceToId(ModelInterface::kUNK()); + const int id = PieceToId(util::min_string_view(model_->unk_piece().data())); if (IsUnknown(id)) return id; return -1; } int SentencePieceProcessor::bos_id() const { - const int id = PieceToId(ModelInterface::kBOS()); + const int id = PieceToId(util::min_string_view(model_->bos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::eos_id() const { - const int id = PieceToId(ModelInterface::kEOS()); + const int id = PieceToId(util::min_string_view(model_->eos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::pad_id() const { - const int id = PieceToId(ModelInterface::kPAD()); + const int id = PieceToId(util::min_string_view(model_->pad_piece().data())); if (IsControl(id)) return id; return -1; } @@ -591,8 +591,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( break; case EOS: { auto *piece = spt->add_pieces(); - piece->set_id(PieceToId("</s>")); - piece->set_piece("</s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->eos_piece().data()))); + piece->set_piece(model_->eos_piece().data(), + model_->eos_piece().size()); } break; case BOS: { auto *array = spt->mutable_pieces(); @@ -601,8 +603,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( array->SwapElements(i - 1, i); } auto *piece = array->Mutable(0); - piece->set_id(PieceToId("<s>")); - piece->set_piece("<s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->bos_piece().data()))); + piece->set_piece(model_->bos_piece().data(), + model_->bos_piece().size()); } break; default: return util::InternalError("unknown extra_option type."); @@ -634,12 
+638,14 @@ util::Status SentencePieceProcessor::ParseExtraOptions( extra_options->push_back(it->second); if (it->second == SentencePieceProcessor::BOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>"))) - << "id for `<s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + PieceToId(util::min_string_view(model_->bos_piece().data())))) + << "id for `" << model_->bos_piece() << "` is not defined."; } if (it->second == SentencePieceProcessor::EOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>"))) - << "id for `</s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + PieceToId(util::min_string_view(model_->eos_piece().data())))) + << "id for `" << model_->eos_piece() << "` is not defined."; } } return util::OkStatus(); |