From 5f635d0892debe2001ea889f0cf02185449fcaec Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sat, 8 Dec 2018 22:08:05 +0900 Subject: support to change the piece of unk/bos/eos/pad --- src/model_interface.cc | 27 +++++++++++++++++---- src/model_interface.h | 8 +++---- src/model_interface_test.cc | 48 +++++++++++++++++++++++++++++++++++++ src/sentencepiece_model.proto | 4 ++++ src/sentencepiece_processor.cc | 30 +++++++++++++---------- src/sentencepiece_processor_test.cc | 34 ++++++++++++++++++++++++++ src/spm_train_main.cc | 12 ++++++++++ src/trainer_interface.cc | 34 +++++++++++++++----------- src/trainer_interface_test.cc | 30 +++++++++++++++++++++++ 9 files changed, 192 insertions(+), 35 deletions(-) (limited to 'src') diff --git a/src/model_interface.cc b/src/model_interface.cc index c69cecc..e46a632 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -20,15 +20,32 @@ namespace sentencepiece { -const char *ModelInterface::kUNK() { return "<unk>"; } -const char *ModelInterface::kBOS() { return "<s>"; } -const char *ModelInterface::kEOS() { return "</s>"; } -const char *ModelInterface::kPAD() { return "<pad>"; }; - ModelInterface::ModelInterface(const ModelProto &model_proto) : model_proto_(&model_proto), status_(util::OkStatus()) {} ModelInterface::~ModelInterface() {} +#define RETURN_PIECE(name, default_value) \ + if (model_proto_->trainer_spec().name().empty()) return default_value; \ + return model_proto_->trainer_spec().name(); + +absl::string_view ModelInterface::unk_piece() const { + RETURN_PIECE(unk_piece, "<unk>"); +} + +absl::string_view ModelInterface::bos_piece() const { + RETURN_PIECE(bos_piece, "<s>"); +} + +absl::string_view ModelInterface::eos_piece() const { + RETURN_PIECE(eos_piece, "</s>"); +} + +absl::string_view ModelInterface::pad_piece() const { + RETURN_PIECE(pad_piece, "<pad>"); +} + +#undef RETURN_PIECE + int ModelInterface::PieceToId(absl::string_view piece) const { auto it = reserved_id_map_.find(piece); if (it != reserved_id_map_.end()) { diff --git 
a/src/model_interface.h b/src/model_interface.h index 320dc86..4a3a44e 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -47,10 +47,10 @@ class ModelInterface { using PieceToIdMap = std::unordered_map; - static const char *kUNK(); - static const char *kBOS(); - static const char *kEOS(); - static const char *kPAD(); + absl::string_view unk_piece() const; + absl::string_view bos_piece() const; + absl::string_view eos_piece() const; + absl::string_view pad_piece() const; // `model_proto` should not be deleted until ModelInterface is destroyed. explicit ModelInterface(const ModelProto &model_proto); diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index a42c931..c1edfc8 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -50,6 +50,54 @@ void AddPiece(ModelProto *model_proto, const std::string &piece, sp->set_score(score); } +TEST(ModelInterfaceTest, GetDefaultPieceTest) { + { + ModelProto model_proto; + EXPECT_EQ("<unk>", model_proto.trainer_spec().unk_piece()); + EXPECT_EQ("<s>", model_proto.trainer_spec().bos_piece()); + EXPECT_EQ("</s>", model_proto.trainer_spec().eos_piece()); + EXPECT_EQ("<pad>", model_proto.trainer_spec().pad_piece()); + } + + { + ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); + AddPiece(&model_proto, "a"); + auto model = ModelFactory::Create(model_proto); + EXPECT_EQ("<unk>", model->unk_piece()); + EXPECT_EQ("<s>", model->bos_piece()); + EXPECT_EQ("</s>", model->eos_piece()); + EXPECT_EQ("<pad>", model->pad_piece()); + } + + { + ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); + AddPiece(&model_proto, "a"); + model_proto.mutable_trainer_spec()->clear_unk_piece(); + model_proto.mutable_trainer_spec()->clear_bos_piece(); + model_proto.mutable_trainer_spec()->clear_eos_piece(); + model_proto.mutable_trainer_spec()->clear_pad_piece(); + auto model = ModelFactory::Create(model_proto); + EXPECT_EQ("<unk>", model->unk_piece()); + EXPECT_EQ("<s>", model->bos_piece()); + EXPECT_EQ("</s>", 
model->eos_piece()); + EXPECT_EQ("<pad>", model->pad_piece()); + } + + { + ModelProto model_proto = MakeBaseModelProto(TrainerSpec::UNIGRAM); + AddPiece(&model_proto, "a"); + model_proto.mutable_trainer_spec()->set_unk_piece("UNK"); + model_proto.mutable_trainer_spec()->set_bos_piece("BOS"); + model_proto.mutable_trainer_spec()->set_eos_piece("EOS"); + model_proto.mutable_trainer_spec()->set_pad_piece("PAD"); + auto model = ModelFactory::Create(model_proto); + EXPECT_EQ("UNK", model->unk_piece()); + EXPECT_EQ("BOS", model->bos_piece()); + EXPECT_EQ("EOS", model->eos_piece()); + EXPECT_EQ("PAD", model->pad_piece()); + } +} + TEST(ModelInterfaceTest, SetModelInterfaceTest) { for (const auto type : kModelTypes) { ModelProto model_proto = MakeBaseModelProto(type); diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index e34c7be..eedbc40 100644 --- a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -166,6 +166,10 @@ message TrainerSpec { optional int32 bos_id = 41 [ default = 1 ]; // <s> optional int32 eos_id = 42 [ default = 2 ]; // </s> optional int32 pad_id = 43 [ default = -1 ]; // <pad> (padding) + optional string unk_piece = 45 [ default = "<unk>" ]; + optional string bos_piece = 46 [ default = "<s>" ]; + optional string eos_piece = 47 [ default = "</s>" ]; + optional string pad_piece = 48 [ default = "<pad>" ]; // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK), // since this character can be useful both for user and diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1f425df..8c9c208 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -556,25 +556,25 @@ bool SentencePieceProcessor::IsUnused(int id) const { } int SentencePieceProcessor::unk_id() const { - const int id = PieceToId(ModelInterface::kUNK()); + const int id = PieceToId(util::min_string_view(model_->unk_piece().data())); if (IsUnknown(id)) return id; return -1; } int SentencePieceProcessor::bos_id() const { - const int id = 
PieceToId(ModelInterface::kBOS()); + const int id = PieceToId(util::min_string_view(model_->bos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::eos_id() const { - const int id = PieceToId(ModelInterface::kEOS()); + const int id = PieceToId(util::min_string_view(model_->eos_piece().data())); if (IsControl(id)) return id; return -1; } int SentencePieceProcessor::pad_id() const { - const int id = PieceToId(ModelInterface::kPAD()); + const int id = PieceToId(util::min_string_view(model_->pad_piece().data())); if (IsControl(id)) return id; return -1; } @@ -591,8 +591,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( break; case EOS: { auto *piece = spt->add_pieces(); - piece->set_id(PieceToId("</s>")); - piece->set_piece("</s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->eos_piece().data()))); + piece->set_piece(model_->eos_piece().data(), + model_->eos_piece().size()); } break; case BOS: { auto *array = spt->mutable_pieces(); @@ -601,8 +603,10 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( array->SwapElements(i - 1, i); } auto *piece = array->Mutable(0); - piece->set_id(PieceToId("<s>")); - piece->set_piece("<s>"); + piece->set_id( + PieceToId(util::min_string_view(model_->bos_piece().data()))); + piece->set_piece(model_->bos_piece().data(), + model_->bos_piece().size()); } break; default: return util::InternalError("unknown extra_option type."); @@ -634,12 +638,14 @@ util::Status SentencePieceProcessor::ParseExtraOptions( extra_options->push_back(it->second); if (it->second == SentencePieceProcessor::BOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("<s>"))) - << "id for `<s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + PieceToId(util::min_string_view(model_->bos_piece().data())))) + << "id for `" << model_->bos_piece() << "` is not defined."; } if (it->second == SentencePieceProcessor::EOS) { - CHECK_OR_RETURN(!IsUnknown(PieceToId("</s>"))) - << "id for `</s>` is not defined."; + CHECK_OR_RETURN(!IsUnknown( + 
PieceToId(util::min_string_view(model_->eos_piece().data())))) + << "id for `" << model_->eos_piece() << "` is not defined."; } } return util::OkStatus(); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index a313b0d..b60bb8b 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -1054,6 +1054,40 @@ TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) { EXPECT_NOT_OK(sp.SetDecodeExtraOptions("eos")); } +TEST(SentencePieceProcessorTest, OverrideSpecialPieceTest) { + ModelProto model_proto; + auto *sp1 = model_proto.add_pieces(); + auto *sp2 = model_proto.add_pieces(); + auto *sp3 = model_proto.add_pieces(); + + model_proto.mutable_trainer_spec()->set_unk_piece("__UNK__"); + model_proto.mutable_trainer_spec()->set_bos_piece("__BOS__"); + model_proto.mutable_trainer_spec()->set_eos_piece("__EOS__"); + model_proto.mutable_trainer_spec()->set_pad_piece("__PAD__"); + + // No BOS/EOS. + sp1->set_type(ModelProto::SentencePiece::UNKNOWN); + sp1->set_piece("__UNK__"); + sp2->set_type(ModelProto::SentencePiece::CONTROL); + sp2->set_piece("__BOS__"); + sp3->set_type(ModelProto::SentencePiece::CONTROL); + sp3->set_piece("__EOS__"); + + AddPiece(&model_proto, "a", 0.0); + AddPiece(&model_proto, "b", 0.3); + + SentencePieceProcessor sp; + EXPECT_OK(sp.Load(model_proto)); + EXPECT_EQ(0, sp.unk_id()); + EXPECT_EQ(1, sp.bos_id()); + EXPECT_EQ(2, sp.eos_id()); + EXPECT_EQ(-1, sp.pad_id()); + + EXPECT_EQ("__UNK__", sp.IdToPiece(sp.unk_id())); + EXPECT_EQ("__BOS__", sp.IdToPiece(sp.bos_id())); + EXPECT_EQ("__EOS__", sp.IdToPiece(sp.eos_id())); +} + TEST(SentencePieceProcessorTest, VocabularyTest) { ModelProto model_proto; auto *sp1 = model_proto.add_pieces(); diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index e1ad01c..9163a30 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -92,6 +92,14 @@ DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(), "Override EOS (</s>) id. 
Set -1 to disable EOS."); DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(), "Override PAD (<pad>) id. Set -1 to disable PAD."); +DEFINE_string(unk_piece, kDefaultTrainerSpec.unk_piece(), + "Override UNK (<unk>) piece."); +DEFINE_string(bos_piece, kDefaultTrainerSpec.bos_piece(), + "Override BOS (<s>) piece."); +DEFINE_string(eos_piece, kDefaultTrainerSpec.eos_piece(), + "Override EOS (</s>) piece."); +DEFINE_string(pad_piece, kDefaultTrainerSpec.pad_piece(), + "Override PAD (<pad>) piece."); DEFINE_string(unk_surface, kDefaultTrainerSpec.unk_surface(), "Dummy surface string for <unk>. In decoding <unk> is decoded to " "`unk_surface`."); @@ -141,6 +149,10 @@ int main(int argc, char *argv[]) { SetTrainerSpecFromFlag(bos_id); SetTrainerSpecFromFlag(eos_id); SetTrainerSpecFromFlag(pad_id); + SetTrainerSpecFromFlag(unk_piece); + SetTrainerSpecFromFlag(bos_piece); + SetTrainerSpecFromFlag(eos_piece); + SetTrainerSpecFromFlag(pad_piece); SetTrainerSpecFromFlag(unk_surface); SetRepeatedTrainerSpecFromFlag(input); SetRepeatedTrainerSpecFromFlag(accept_language); diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 70c9c39..a82c477 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -71,6 +71,11 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) { CHECK_GE_OR_RETURN(trainer_spec.seed_sentencepiece_size(), 1000); CHECK_GE_OR_RETURN(trainer_spec.training_sentence_size(), 100); + CHECK_OR_RETURN(!trainer_spec.unk_piece().empty()); + CHECK_OR_RETURN(!trainer_spec.bos_piece().empty()); + CHECK_OR_RETURN(!trainer_spec.eos_piece().empty()); + CHECK_OR_RETURN(!trainer_spec.pad_piece().empty()); + return util::OkStatus(); } } // namespace @@ -438,21 +443,21 @@ util::Status TrainerInterface::InitMetaPieces() { if (id < 0) return true; if (id >= trainer_spec_.vocab_size() || meta_pieces_.find(id) != meta_pieces_.end() || - (has_unk && w == ModelInterface::kUNK())) + (has_unk && w == trainer_spec_.unk_piece())) return false; - if (w == ModelInterface::kUNK()) has_unk = true; + if (w == trainer_spec_.unk_piece()) has_unk = 
true; + if (w == trainer_spec_.unk_piece()) has_unk = true; meta_pieces_[id] = std::make_pair( - w, w == ModelInterface::kUNK() ? ModelProto::SentencePiece::UNKNOWN - : ModelProto::SentencePiece::CONTROL); + w, w == trainer_spec_.unk_piece() ? ModelProto::SentencePiece::UNKNOWN + : ModelProto::SentencePiece::CONTROL); return true; }; - CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), ModelInterface::kUNK())); - CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), ModelInterface::kBOS())); - CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), ModelInterface::kEOS())); - CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), ModelInterface::kPAD())); + CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), trainer_spec_.unk_piece())); + CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), trainer_spec_.bos_piece())); + CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), trainer_spec_.eos_piece())); + CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), trainer_spec_.pad_piece())); - CHECK_OR_RETURN(has_unk) << ModelInterface::kUNK() << " must be defined."; + CHECK_OR_RETURN(has_unk) << trainer_spec_.unk_piece() << " must be defined."; std::set dup; @@ -465,17 +470,18 @@ util::Status TrainerInterface::InitMetaPieces() { return false; } - if (w == ModelInterface::kUNK()) { - LOG(ERROR) << " must not be defined with --control_symbols and " + if (w == trainer_spec_.unk_piece()) { + LOG(ERROR) << trainer_spec_.unk_piece() + << " must not be defined with --control_symbols and " "--user_defined_symbols."; return false; } - if (w == ModelInterface::kBOS() && trainer_spec_.bos_id() >= 0) { + if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) { meta_pieces_[trainer_spec_.bos_id()].second = type; - } else if (w == ModelInterface::kEOS() && trainer_spec_.eos_id() >= 0) { + } else if (w == trainer_spec_.eos_piece() && trainer_spec_.eos_id() >= 0) { meta_pieces_[trainer_spec_.eos_id()].second = type; - } else if (w == ModelInterface::kPAD() && trainer_spec_.pad_id() >= 0) { + } else if 
(w == trainer_spec_.pad_piece() && trainer_spec_.pad_id() >= 0) { meta_pieces_[trainer_spec_.pad_id()].second = type; } else { while (meta_pieces_.find(id) != meta_pieces_.end()) ++id; diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index 34feb85..45d4905 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -288,6 +288,36 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) { EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED, trainer.meta_pieces_[1].second); } + + { + auto trainer_spec = base_trainer_spec; + trainer_spec.set_unk_piece("__UNK__"); + trainer_spec.set_bos_piece("__BOS__"); + trainer_spec.set_eos_piece("__EOS__"); + trainer_spec.set_pad_piece("__PAD__"); + trainer_spec.set_pad_id(3); + TrainerInterface trainer(trainer_spec, normalizer_spec); + EXPECT_TRUE(trainer.status().ok()); + EXPECT_EQ("__UNK__", trainer.meta_pieces_[0].first); + EXPECT_EQ("__BOS__", trainer.meta_pieces_[1].first); + EXPECT_EQ("__EOS__", trainer.meta_pieces_[2].first); + EXPECT_EQ("__PAD__", trainer.meta_pieces_[3].first); + } + + { + auto trainer_spec = base_trainer_spec; + trainer_spec.set_unk_piece("__UNK__"); + trainer_spec.set_bos_piece("__UNK__"); + TrainerInterface trainer(trainer_spec, normalizer_spec); + EXPECT_FALSE(trainer.status().ok()); + } + + { + auto trainer_spec = base_trainer_spec; + trainer_spec.set_unk_piece(""); + TrainerInterface trainer(trainer_spec, normalizer_spec); + EXPECT_FALSE(trainer.status().ok()); + } } TEST(TrainerInterfaceTest, SerializeTest) { -- cgit v1.2.3