Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2018-12-08 16:08:05 +0300
committer: Taku Kudo <taku@google.com> 2018-12-08 16:08:05 +0300
commit: 5f635d0892debe2001ea889f0cf02185449fcaec (patch)
tree: e84d1651a56341eef9cd89c9bdb46a878a952de4 /src/sentencepiece_processor.cc
parent: f6229bfad55f28e6b35e6860d9c06a5ea9bd83c9 (diff)
support to change the piece of unk/bos/eos/pad
Diffstat (limited to 'src/sentencepiece_processor.cc')
-rw-r--r-- src/sentencepiece_processor.cc 30
1 files changed, 18 insertions, 12 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 1f425df..8c9c208 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -556,25 +556,25 @@ bool SentencePieceProcessor::IsUnused(int id) const {
}
int SentencePieceProcessor::unk_id() const {
- const int id = PieceToId(ModelInterface::kUNK());
+ const int id = PieceToId(util::min_string_view(model_->unk_piece().data()));
if (IsUnknown(id)) return id;
return -1;
}
int SentencePieceProcessor::bos_id() const {
- const int id = PieceToId(ModelInterface::kBOS());
+ const int id = PieceToId(util::min_string_view(model_->bos_piece().data()));
if (IsControl(id)) return id;
return -1;
}
int SentencePieceProcessor::eos_id() const {
- const int id = PieceToId(ModelInterface::kEOS());
+ const int id = PieceToId(util::min_string_view(model_->eos_piece().data()));
if (IsControl(id)) return id;
return -1;
}
int SentencePieceProcessor::pad_id() const {
- const int id = PieceToId(ModelInterface::kPAD());
+ const int id = PieceToId(util::min_string_view(model_->pad_piece().data()));
if (IsControl(id)) return id;
return -1;
}
@@ -591,8 +591,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
break;
case EOS: {
auto *piece = spt->add_pieces();
- piece->set_id(PieceToId("</s>"));
- piece->set_piece("</s>");
+ piece->set_id(
+ PieceToId(util::min_string_view(model_->eos_piece().data())));
+ piece->set_piece(model_->eos_piece().data(),
+ model_->eos_piece().size());
} break;
case BOS: {
auto *array = spt->mutable_pieces();
@@ -601,8 +603,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
array->SwapElements(i - 1, i);
}
auto *piece = array->Mutable(0);
- piece->set_id(PieceToId("<s>"));
- piece->set_piece("<s>");
+ piece->set_id(
+ PieceToId(util::min_string_view(model_->bos_piece().data())));
+ piece->set_piece(model_->bos_piece().data(),
+ model_->bos_piece().size());
} break;
default:
return util::InternalError("unknown extra_option type.");
@@ -634,12 +638,14 @@ util::Status SentencePieceProcessor::ParseExtraOptions(
extra_options->push_back(it->second);
if (it->second == SentencePieceProcessor::BOS) {
- CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>")))
- << "id for `<s>` is not defined.";
+ CHECK_OR_RETURN(!IsUnknown(
+ PieceToId(util::min_string_view(model_->bos_piece().data()))))
+ << "id for `" << model_->bos_piece() << "` is not defined.";
}
if (it->second == SentencePieceProcessor::EOS) {
- CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>")))
- << "id for `</s>` is not defined.";
+ CHECK_OR_RETURN(!IsUnknown(
+ PieceToId(util::min_string_view(model_->eos_piece().data()))))
+ << "id for `" << model_->eos_piece() << "` is not defined.";
}
}
return util::OkStatus();