Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/model_interface.cc 27
-rw-r--r-- src/model_interface.h 8
-rw-r--r-- src/model_interface_test.cc 48
-rw-r--r-- src/sentencepiece_model.proto 4
-rw-r--r-- src/sentencepiece_processor.cc 30
-rw-r--r-- src/sentencepiece_processor_test.cc 34
-rw-r--r-- src/spm_train_main.cc 12
-rw-r--r-- src/trainer_interface.cc 34
-rw-r--r-- src/trainer_interface_test.cc 30
9 files changed, 192 insertions, 35 deletions
diff --git a/src/model_interface.cc b/src/model_interface.cc
index c69cecc..e46a632 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -20,15 +20,32 @@
namespace sentencepiece {
-const char *ModelInterface::kUNK() { return "<unk>"; }
-const char *ModelInterface::kBOS() { return "<s>"; }
-const char *ModelInterface::kEOS() { return "</s>"; }
-const char *ModelInterface::kPAD() { return "<pad>"; };
-
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto), status_(util::OkStatus()) {}
ModelInterface::~ModelInterface() {}
+#define RETURN_PIECE(name, default_value) \
+ if (model_proto_->trainer_spec().name().empty()) return default_value; \
+ return model_proto_->trainer_spec().name();
+
+absl::string_view ModelInterface::unk_piece() const {  // trainer_spec().unk_piece(), or "<unk>" if it is empty
+ RETURN_PIECE(unk_piece, "<unk>");
+}
+
+absl::string_view ModelInterface::bos_piece() const {  // trainer_spec().bos_piece(), or "<s>" if it is empty
+ RETURN_PIECE(bos_piece, "<s>");
+}
+
+absl::string_view ModelInterface::eos_piece() const {  // trainer_spec().eos_piece(), or "</s>" if it is empty
+ RETURN_PIECE(eos_piece, "</s>");
+}
+
+absl::string_view ModelInterface::pad_piece() const {  // trainer_spec().pad_piece(), or "<pad>" if it is empty
+ RETURN_PIECE(pad_piece, "<pad>");
+}
+
+#undef RETURN_PIECE  // helper macro is only meant for the four accessors above
+
+
int ModelInterface::PieceToId(absl::string_view piece) const {
auto it = reserved_id_map_.find(piece);
if (it != reserved_id_map_.end()) {
diff --git a/src/model_interface.h b/src/model_interface.h
index 320dc86..4a3a44e 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -47,10 +47,10 @@ class ModelInterface {
using PieceToIdMap =
std::unordered_map<absl::string_view, int, string_util::string_view_hash>;
- static const char *kUNK();
- static const char *kBOS();
- static const char *kEOS();
- static const char *kPAD();
+ absl::string_view unk_piece() const;
+ absl::string_view bos_piece() const;
+ absl::string_view eos_piece() const;
+ absl::string_view pad_piece() const;
// `model_proto` should not be deleted until ModelInterface is destroyed.
explicit ModelInterface(const ModelProto &model_proto);
diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc
index a42c931..c1edfc8 100644
--- a/src/model_interface_test.cc
+++ b/src/model_interface_test.cc
@@ -50,6 +50,54 @@ void AddPiece(ModelProto *model_proto, const std::string &piece,
sp->set_score(score);
}
+TEST(ModelInterfaceTest, GetDefaultPieceTest) {  // covers the *_piece() accessors: defaults and overrides
+ {  // proto level: an untouched ModelProto already reports the standard piece strings
+ ModelProto model_proto;
+ EXPECT_EQ("<unk>", model_proto.trainer_spec().unk_piece());
+ EXPECT_EQ("<s>", model_proto.trainer_spec().bos_piece());
+ EXPECT_EQ("</s>", model_proto.trainer_spec().eos_piece());
+ EXPECT_EQ("<pad>", model_proto.trainer_spec().pad_piece());
+ }
+
+ {  // model level: no piece override is applied in this scope
+ ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
+ AddPiece(&model_proto, "a");
+ auto model = ModelFactory::Create(model_proto);
+ EXPECT_EQ("<unk>", model->unk_piece());
+ EXPECT_EQ("<s>", model->bos_piece());
+ EXPECT_EQ("</s>", model->eos_piece());
+ EXPECT_EQ("<pad>", model->pad_piece());
+ }
+
+ {  // explicitly cleared fields must still resolve to the standard pieces
+ ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
+ AddPiece(&model_proto, "a");
+ model_proto.mutable_trainer_spec()->clear_unk_piece();
+ model_proto.mutable_trainer_spec()->clear_bos_piece();
+ model_proto.mutable_trainer_spec()->clear_eos_piece();
+ model_proto.mutable_trainer_spec()->clear_pad_piece();
+ auto model = ModelFactory::Create(model_proto);
+ EXPECT_EQ("<unk>", model->unk_piece());
+ EXPECT_EQ("<s>", model->bos_piece());
+ EXPECT_EQ("</s>", model->eos_piece());
+ EXPECT_EQ("<pad>", model->pad_piece());
+ }
+
+ {  // custom pieces set on trainer_spec are returned verbatim by the model
+ ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
+ AddPiece(&model_proto, "a");
+ model_proto.mutable_trainer_spec()->set_unk_piece("UNK");
+ model_proto.mutable_trainer_spec()->set_bos_piece("BOS");
+ model_proto.mutable_trainer_spec()->set_eos_piece("EOS");
+ model_proto.mutable_trainer_spec()->set_pad_piece("PAD");
+ auto model = ModelFactory::Create(model_proto);
+ EXPECT_EQ("UNK", model->unk_piece());
+ EXPECT_EQ("BOS", model->bos_piece());
+ EXPECT_EQ("EOS", model->eos_piece());
+ EXPECT_EQ("PAD", model->pad_piece());
+ }
+}
+
+
TEST(ModelInterfaceTest, SetModelInterfaceTest) {
for (const auto type : kModelTypes) {
ModelProto model_proto = MakeBaseModelProto(type);
diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto
index e34c7be..eedbc40 100644
--- a/src/sentencepiece_model.proto
+++ b/src/sentencepiece_model.proto
@@ -166,6 +166,10 @@ message TrainerSpec {
optional int32 bos_id = 41 [ default = 1 ]; // <s>
optional int32 eos_id = 42 [ default = 2 ]; // </s>
optional int32 pad_id = 43 [ default = -1 ]; // <pad> (padding)
+ optional string unk_piece = 45 [ default = "<unk>" ];  // piece string paired with unk_id
+ optional string bos_piece = 46 [ default = "<s>" ];  // piece string paired with bos_id
+ optional string eos_piece = 47 [ default = "</s>" ];  // piece string paired with eos_id
+ optional string pad_piece = 48 [ default = "<pad>" ];  // piece string paired with pad_id
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 1f425df..8c9c208 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -556,25 +556,25 @@ bool SentencePieceProcessor::IsUnused(int id) const {
}
int SentencePieceProcessor::unk_id() const {  // id of the model's configured unk piece, or -1
- const int id = PieceToId(ModelInterface::kUNK());
+ const int id = PieceToId(util::min_string_view(model_->unk_piece().data()));
if (IsUnknown(id)) return id;  // accept the id only when IsUnknown() confirms it
return -1;  // the lookup did not yield an unknown piece
}
int SentencePieceProcessor::bos_id() const {  // id of the model's configured bos piece, or -1
- const int id = PieceToId(ModelInterface::kBOS());
+ const int id = PieceToId(util::min_string_view(model_->bos_piece().data()));
if (IsControl(id)) return id;  // accept the id only when IsControl() confirms it
return -1;  // the lookup did not yield a control piece
}
int SentencePieceProcessor::eos_id() const {  // id of the model's configured eos piece, or -1
- const int id = PieceToId(ModelInterface::kEOS());
+ const int id = PieceToId(util::min_string_view(model_->eos_piece().data()));
if (IsControl(id)) return id;  // accept the id only when IsControl() confirms it
return -1;  // the lookup did not yield a control piece
}
int SentencePieceProcessor::pad_id() const {  // id of the model's configured pad piece, or -1
- const int id = PieceToId(ModelInterface::kPAD());
+ const int id = PieceToId(util::min_string_view(model_->pad_piece().data()));
if (IsControl(id)) return id;  // accept the id only when IsControl() confirms it
return -1;  // the lookup did not yield a control piece
}
@@ -591,8 +591,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
break;
case EOS: {
auto *piece = spt->add_pieces();
- piece->set_id(PieceToId("</s>"));
- piece->set_piece("</s>");
+ piece->set_id(
+ PieceToId(util::min_string_view(model_->eos_piece().data())));
+ piece->set_piece(model_->eos_piece().data(),
+ model_->eos_piece().size());
} break;
case BOS: {
auto *array = spt->mutable_pieces();
@@ -601,8 +603,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
array->SwapElements(i - 1, i);
}
auto *piece = array->Mutable(0);
- piece->set_id(PieceToId("<s>"));
- piece->set_piece("<s>");
+ piece->set_id(
+ PieceToId(util::min_string_view(model_->bos_piece().data())));
+ piece->set_piece(model_->bos_piece().data(),
+ model_->bos_piece().size());
} break;
default:
return util::InternalError("unknown extra_option type.");
@@ -634,12 +638,14 @@ util::Status SentencePieceProcessor::ParseExtraOptions(
extra_options->push_back(it->second);
if (it->second == SentencePieceProcessor::BOS) {
- CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>")))
- << "id for `<s>` is not defined.";
+ CHECK_OR_RETURN(!IsUnknown(
+ PieceToId(util::min_string_view(model_->bos_piece().data()))))
+ << "id for `" << model_->bos_piece() << "` is not defined.";
}
if (it->second == SentencePieceProcessor::EOS) {
- CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>")))
- << "id for `</s>` is not defined.";
+ CHECK_OR_RETURN(!IsUnknown(
+ PieceToId(util::min_string_view(model_->eos_piece().data()))))
+ << "id for `" << model_->eos_piece() << "` is not defined.";
}
}
return util::OkStatus();
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index a313b0d..b60bb8b 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -1054,6 +1054,40 @@ TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) {
EXPECT_NOT_OK(sp.SetDecodeExtraOptions("eos"));
}
+TEST(SentencePieceProcessorTest, OverrideSpecialPieceTest) {  // *_id() must follow the renamed special pieces
+ ModelProto model_proto;
+ auto *sp1 = model_proto.add_pieces();
+ auto *sp2 = model_proto.add_pieces();
+ auto *sp3 = model_proto.add_pieces();
+
+ model_proto.mutable_trainer_spec()->set_unk_piece("__UNK__");
+ model_proto.mutable_trainer_spec()->set_bos_piece("__BOS__");
+ model_proto.mutable_trainer_spec()->set_eos_piece("__EOS__");
+ model_proto.mutable_trainer_spec()->set_pad_piece("__PAD__");
+
+ // UNK, BOS and EOS pieces are defined below; no PAD piece is added.
+ sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
+ sp1->set_piece("__UNK__");
+ sp2->set_type(ModelProto::SentencePiece::CONTROL);
+ sp2->set_piece("__BOS__");
+ sp3->set_type(ModelProto::SentencePiece::CONTROL);
+ sp3->set_piece("__EOS__");
+
+ AddPiece(&model_proto, "a", 0.0);
+ AddPiece(&model_proto, "b", 0.3);
+
+ SentencePieceProcessor sp;
+ EXPECT_OK(sp.Load(model_proto));
+ EXPECT_EQ(0, sp.unk_id());  // ids follow the order in which the pieces were added
+ EXPECT_EQ(1, sp.bos_id());
+ EXPECT_EQ(2, sp.eos_id());
+ EXPECT_EQ(-1, sp.pad_id());  // "__PAD__" is configured but was never added as a piece
+
+ EXPECT_EQ("__UNK__", sp.IdToPiece(sp.unk_id()));
+ EXPECT_EQ("__BOS__", sp.IdToPiece(sp.bos_id()));
+ EXPECT_EQ("__EOS__", sp.IdToPiece(sp.eos_id()));
+}
+
+
TEST(SentencePieceProcessorTest, VocabularyTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index e1ad01c..9163a30 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -92,6 +92,14 @@ DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
"Override EOS (</s>) id. Set -1 to disable EOS.");
DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
"Override PAD (<pad>) id. Set -1 to disable PAD.");
+DEFINE_string(unk_piece, kDefaultTrainerSpec.unk_piece(),
+ "Override UNK (<unk>) piece.");
+DEFINE_string(bos_piece, kDefaultTrainerSpec.bos_piece(),
+ "Override BOS (<s>) piece.");
+DEFINE_string(eos_piece, kDefaultTrainerSpec.eos_piece(),
+ "Override EOS (</s>) piece.");
+DEFINE_string(pad_piece, kDefaultTrainerSpec.pad_piece(),
+ "Override PAD (<pad>) piece.");
DEFINE_string(unk_surface, kDefaultTrainerSpec.unk_surface(),
"Dummy surface string for <unk>. In decoding <unk> is decoded to "
"`unk_surface`.");
@@ -141,6 +149,10 @@ int main(int argc, char *argv[]) {
SetTrainerSpecFromFlag(bos_id);
SetTrainerSpecFromFlag(eos_id);
SetTrainerSpecFromFlag(pad_id);
+ SetTrainerSpecFromFlag(unk_piece);
+ SetTrainerSpecFromFlag(bos_piece);
+ SetTrainerSpecFromFlag(eos_piece);
+ SetTrainerSpecFromFlag(pad_piece);
SetTrainerSpecFromFlag(unk_surface);
SetRepeatedTrainerSpecFromFlag(input);
SetRepeatedTrainerSpecFromFlag(accept_language);
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index 70c9c39..a82c477 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -71,6 +71,11 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
CHECK_GE_OR_RETURN(trainer_spec.seed_sentencepiece_size(), 1000);
CHECK_GE_OR_RETURN(trainer_spec.training_sentence_size(), 100);
+ CHECK_OR_RETURN(!trainer_spec.unk_piece().empty());
+ CHECK_OR_RETURN(!trainer_spec.bos_piece().empty());
+ CHECK_OR_RETURN(!trainer_spec.eos_piece().empty());
+ CHECK_OR_RETURN(!trainer_spec.pad_piece().empty());
+
return util::OkStatus();
}
} // namespace
@@ -438,21 +443,21 @@ util::Status TrainerInterface::InitMetaPieces() {
if (id < 0) return true;
if (id >= trainer_spec_.vocab_size() ||
meta_pieces_.find(id) != meta_pieces_.end() ||
- (has_unk && w == ModelInterface::kUNK()))
+ (has_unk && w == trainer_spec_.unk_piece()))
return false;
- if (w == ModelInterface::kUNK()) has_unk = true;
+ if (w == trainer_spec_.unk_piece()) has_unk = true;
meta_pieces_[id] = std::make_pair(
- w, w == ModelInterface::kUNK() ? ModelProto::SentencePiece::UNKNOWN
- : ModelProto::SentencePiece::CONTROL);
+ w, w == trainer_spec_.unk_piece() ? ModelProto::SentencePiece::UNKNOWN
+ : ModelProto::SentencePiece::CONTROL);
return true;
};
- CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), ModelInterface::kUNK()));
- CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), ModelInterface::kBOS()));
- CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), ModelInterface::kEOS()));
- CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), ModelInterface::kPAD()));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), trainer_spec_.unk_piece()));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), trainer_spec_.bos_piece()));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), trainer_spec_.eos_piece()));
+ CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), trainer_spec_.pad_piece()));
- CHECK_OR_RETURN(has_unk) << ModelInterface::kUNK() << " must be defined.";
+ CHECK_OR_RETURN(has_unk) << trainer_spec_.unk_piece() << " must be defined.";
std::set<std::string> dup;
@@ -465,17 +470,18 @@ util::Status TrainerInterface::InitMetaPieces() {
return false;
}
- if (w == ModelInterface::kUNK()) {
- LOG(ERROR) << "<unk> must not be defined with --control_symbols and "
+ if (w == trainer_spec_.unk_piece()) {
+ LOG(ERROR) << trainer_spec_.unk_piece()
+ << " must not be defined with --control_symbols and "
"--user_defined_symbols.";
return false;
}
- if (w == ModelInterface::kBOS() && trainer_spec_.bos_id() >= 0) {
+ if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) {
meta_pieces_[trainer_spec_.bos_id()].second = type;
- } else if (w == ModelInterface::kEOS() && trainer_spec_.eos_id() >= 0) {
+ } else if (w == trainer_spec_.eos_piece() && trainer_spec_.eos_id() >= 0) {
meta_pieces_[trainer_spec_.eos_id()].second = type;
- } else if (w == ModelInterface::kPAD() && trainer_spec_.pad_id() >= 0) {
+ } else if (w == trainer_spec_.pad_piece() && trainer_spec_.pad_id() >= 0) {
meta_pieces_[trainer_spec_.pad_id()].second = type;
} else {
while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc
index 34feb85..45d4905 100644
--- a/src/trainer_interface_test.cc
+++ b/src/trainer_interface_test.cc
@@ -288,6 +288,36 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
trainer.meta_pieces_[1].second);
}
+
+ {  // all four pieces renamed; pad_id is set to 3 and "__PAD__" is expected in that slot
+ auto trainer_spec = base_trainer_spec;
+ trainer_spec.set_unk_piece("__UNK__");
+ trainer_spec.set_bos_piece("__BOS__");
+ trainer_spec.set_eos_piece("__EOS__");
+ trainer_spec.set_pad_piece("__PAD__");
+ trainer_spec.set_pad_id(3);
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_TRUE(trainer.status().ok());
+ EXPECT_EQ("__UNK__", trainer.meta_pieces_[0].first);
+ EXPECT_EQ("__BOS__", trainer.meta_pieces_[1].first);
+ EXPECT_EQ("__EOS__", trainer.meta_pieces_[2].first);
+ EXPECT_EQ("__PAD__", trainer.meta_pieces_[3].first);
+ }
+
+ {  // bos_piece equal to unk_piece must leave the trainer in an error state
+ auto trainer_spec = base_trainer_spec;
+ trainer_spec.set_unk_piece("__UNK__");
+ trainer_spec.set_bos_piece("__UNK__");
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_FALSE(trainer.status().ok());
+ }
+
+ {  // an empty unk_piece must leave the trainer in an error state
+ auto trainer_spec = base_trainer_spec;
+ trainer_spec.set_unk_piece("");
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_FALSE(trainer.status().ok());
+ }
}
TEST(TrainerInterfaceTest, SerializeTest) {