Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author Taku Kudo <taku910@users.noreply.github.com> 2018-04-09 12:44:57 +0300
committer GitHub <noreply@github.com> 2018-04-09 12:44:57 +0300
commit ecbd55ac54c37184699628da5f148f9ce0770297 (patch)
tree 0e41c4ac3fd6946fdcb8a39c072381db766661f0 /src
parent f75834c16703a8057eddb93f39ef7d075faccb16 (diff)
parent d1028974960d9e7ac9b408f6c212aa90d7c958cb (diff)
Merge pull request #53 from google/sr
Support to change ids of <unk>, <s>, </s>
Diffstat (limited to 'src')
-rw-r--r-- src/bpe_model.cc 2
-rw-r--r-- src/bpe_model_trainer.cc 5
-rw-r--r-- src/char_model.cc 2
-rw-r--r-- src/char_model_trainer.cc 6
-rw-r--r-- src/model_interface.cc 21
-rw-r--r-- src/model_interface.h 7
-rw-r--r-- src/sentencepiece_model.proto 15
-rw-r--r-- src/sentencepiece_processor.cc 16
-rw-r--r-- src/sentencepiece_trainer.cc 15
-rw-r--r-- src/trainer_interface.cc 105
-rw-r--r-- src/trainer_interface.h 14
-rw-r--r-- src/trainer_interface_test.cc 99
-rw-r--r-- src/unigram_model.cc 7
-rw-r--r-- src/unigram_model_trainer.cc 6
-rw-r--r-- src/word_model.cc 2
-rw-r--r-- src/word_model_trainer.cc 5
16 files changed, 239 insertions, 88 deletions
diff --git a/src/bpe_model.cc b/src/bpe_model.cc
index fdc5fd7..fe9df73 100644
--- a/src/bpe_model.cc
+++ b/src/bpe_model.cc
@@ -22,7 +22,6 @@ namespace bpe {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
- CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
@@ -35,6 +34,7 @@ Model::Model(const ModelProto &model_proto) {
LOG(FATAL) << "User defined symbol is not supported in BPE";
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
+ if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
}
}
}
diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc
index 62892ae..e68e0b5 100644
--- a/src/bpe_model_trainer.cc
+++ b/src/bpe_model_trainer.cc
@@ -211,11 +211,8 @@ void Trainer::Train() {
}
}
- const int meta_symbols_size = trainer_spec_.control_symbols().size() +
- trainer_spec_.user_defined_symbols().size() +
- 3; // <s>, </s>, <unk>
const int vocab_size =
- trainer_spec_.vocab_size() - meta_symbols_size - required_chars_.size();
+ trainer_spec_.vocab_size() - meta_pieces_.size() - required_chars_.size();
CHECK_GE(vocab_size, 0);
// We may see duplicated pieces that are extracted with different path.
diff --git a/src/char_model.cc b/src/char_model.cc
index 56df8dc..e11cbed 100644
--- a/src/char_model.cc
+++ b/src/char_model.cc
@@ -20,7 +20,6 @@ namespace character {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
- CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
@@ -31,6 +30,7 @@ Model::Model(const ModelProto &model_proto) {
port::InsertOrDie(&pieces_, sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
+ if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
}
}
}
diff --git a/src/char_model_trainer.cc b/src/char_model_trainer.cc
index 202a15b..f6836ac 100644
--- a/src/char_model_trainer.cc
+++ b/src/char_model_trainer.cc
@@ -40,11 +40,7 @@ void Trainer::Train() {
LoadSentences();
- const int meta_symbols_size = trainer_spec_.control_symbols().size() +
- trainer_spec_.user_defined_symbols().size() +
- 3; // <s>, </s>, <unk>
-
- const int vocab_size = trainer_spec_.vocab_size() - meta_symbols_size;
+ const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size();
CHECK_GE(vocab_size, 0);
uint64 sum = 0;
diff --git a/src/model_interface.cc b/src/model_interface.cc
index 05c25d5..d4602ea 100644
--- a/src/model_interface.cc
+++ b/src/model_interface.cc
@@ -18,8 +18,6 @@
namespace sentencepiece {
-const uint32 ModelInterface::kUnkID = 0;
-
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto) {}
ModelInterface::~ModelInterface() {}
@@ -33,7 +31,7 @@ int ModelInterface::PieceToId(StringPiece piece) const {
if (it2 != pieces_.end()) {
return it2->second;
}
- return kUnkID;
+ return unk_id_;
}
int ModelInterface::GetPieceSize() const {
@@ -58,23 +56,6 @@ bool ModelInterface::IsUnknown(int id) const {
ModelProto::SentencePiece::UNKNOWN);
}
-void ModelInterface::CheckControlSymbols() const {
- CHECK_NOTNULL(model_proto_);
-
- CHECK_GE(model_proto_->pieces_size(), 3); // <unk>, <s>, </s>
-
- // Verify reserved control symbols and unknon symbol.
- CHECK_EQ(ModelProto::SentencePiece::UNKNOWN, // <unk>
- model_proto_->pieces(0).type());
- CHECK_EQ("<unk>", model_proto_->pieces(0).piece());
- CHECK_EQ(ModelProto::SentencePiece::CONTROL, // <s>
- model_proto_->pieces(1).type());
- CHECK_EQ("<s>", model_proto_->pieces(1).piece());
- CHECK_EQ(ModelProto::SentencePiece::CONTROL, // </s>
- model_proto_->pieces(2).type());
- CHECK_EQ("</s>", model_proto_->pieces(2).piece());
-}
-
std::vector<StringPiece> SplitIntoWords(StringPiece text) {
const char *begin = text.data();
const char *end = text.data() + text.size();
diff --git a/src/model_interface.h b/src/model_interface.h
index b472bca..add033b 100644
--- a/src/model_interface.h
+++ b/src/model_interface.h
@@ -39,8 +39,6 @@ class ModelInterface {
public:
using PieceToIdMap = std::unordered_map<StringPiece, int, StringPieceHash>;
- static const uint32 kUnkID;
-
// |model_proto| should not be deleted until ModelInterface is destroyed.
explicit ModelInterface(const ModelProto &model_proto);
ModelInterface() {}
@@ -89,8 +87,6 @@ class ModelInterface {
virtual bool IsControl(int id) const;
protected:
- void CheckControlSymbols() const;
-
const ModelProto *model_proto_ = nullptr;
// piece -> id map for normal pieces
@@ -98,6 +94,9 @@ class ModelInterface {
// piece -> id map for control and unknown
PieceToIdMap reserved_id_map_;
+
+ // unknown id.
+ int unk_id_ = 0;
};
} // namespace sentencepiece
#endif // MODEL_INTERFACE_H_
diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto
index c262a6e..3cd73f5 100644
--- a/src/sentencepiece_model.proto
+++ b/src/sentencepiece_model.proto
@@ -31,6 +31,11 @@ message TrainerSpec {
// with this parameter.
repeated string input = 1;
+ // Input corpus format:
+ // "text": one-sentence-per-line text format (default)
+ // "tsv": sentence <tab> freq
+ optional string input_format = 7;
+
// Output model file prefix.
// <model_prefix>.model and <model_prefix>.vocab are generated.
optional string model_prefix = 2;
@@ -127,6 +132,16 @@ message TrainerSpec {
// Typical usage of user_defined_symbols is placeholder for named entities.
repeated string user_defined_symbols = 31;
+ ///////////////////////////////////////////////////////////////////
+ // Reserved special meta tokens.
+ // * -1 is not used.
+ // * unk_id must not be -1.
+ // Id must starts with 0 and be contigous.
+ optional int32 unk_id = 40 [ default = 0 ]; // <unk>
+ optional int32 bos_id = 41 [ default = 1 ]; // <s>
+ optional int32 eos_id = 42 [ default = 2 ]; // </s>
+ optional int32 pad_id = 43 [ default = -1 ]; // <pad> (padding)
+
// Customized extensions: the range of field numbers
// are open to third-party extensions.
extensions 200 to max;
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 0879b6e..612de21 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -383,9 +383,6 @@ bool SentencePieceProcessor::IsUnknown(int id) const {
void SentencePieceProcessor::ApplyExtraOptions(
const std::vector<ExtraOption> &extra_options,
SentencePieceText *spt) const {
- constexpr int kBOS = 1;
- constexpr int kEOS = 2;
-
for (const auto &extra_option : extra_options) {
switch (extra_option) {
case REVERSE:
@@ -394,9 +391,9 @@ void SentencePieceProcessor::ApplyExtraOptions(
break;
case EOS: {
auto *piece = spt->add_pieces();
- piece->set_id(kEOS);
- piece->set_piece(IdToPiece(kEOS));
- } break;
+ piece->set_id(PieceToId("</s>"));
+ piece->set_piece("</s>");
+ } break;
case BOS: {
auto *array = spt->mutable_pieces();
array->Add();
@@ -404,16 +401,17 @@ void SentencePieceProcessor::ApplyExtraOptions(
array->SwapElements(i - 1, i);
}
auto *piece = array->Mutable(0);
- piece->set_id(kBOS);
- piece->set_piece(IdToPiece(kBOS));
+ piece->set_id(PieceToId("<s>"));
+ piece->set_piece("<s>");
} break;
default:
LOG(FATAL) << "Unknown extra_option type: "
<< static_cast<int>(extra_option);
- }
+ }
}
}
+
// static
std::vector<SentencePieceProcessor::ExtraOption>
SentencePieceProcessor::ParseExtraOptions(const std::string &extra_option) {
diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc
index 0f2873f..ad7d8ea 100644
--- a/src/sentencepiece_trainer.cc
+++ b/src/sentencepiece_trainer.cc
@@ -31,6 +31,8 @@ static const sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
} // namespace
DEFINE_string(input, "", "comma separated list of input sentences");
+DEFINE_string(input_format, kDefaultTrainerSpec.input_format(),
+ "Input format. Supported format is `text` or `tsv`.");
DEFINE_string(model_prefix, "", "output model prefix");
DEFINE_string(model_type, "unigram",
"model algorithm: unigram, bpe, word or char");
@@ -76,6 +78,14 @@ DEFINE_bool(remove_extra_whitespaces,
kDefaultNormalizerSpec.remove_extra_whitespaces(),
"Removes leading, trailing, and "
"duplicate internal whitespace");
+DEFINE_int32(unk_id, kDefaultTrainerSpec.unk_id(),
+ "Override UNK (<unk>) id.");
+DEFINE_int32(bos_id, kDefaultTrainerSpec.bos_id(),
+ "Override BOS (<s>) id. Set -1 to disable BOS.");
+DEFINE_int32(eos_id, kDefaultTrainerSpec.eos_id(),
+ "Override EOS (</s>) id. Set -1 to disable EOS.");
+DEFINE_int32(pad_id, kDefaultTrainerSpec.pad_id(),
+ "Override PAD (<pad>) id. Set -1 to disable PAD.");
using sentencepiece::NormalizerSpec;
using sentencepiece::TrainerSpec;
@@ -128,6 +138,7 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
CHECK_NOTNULL(trainer_spec);
CHECK_NOTNULL(normalizer_spec);
+ SetTrainerSpecFromFlag(input_format);
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
SetTrainerSpecFromFlag(character_coverage);
@@ -141,6 +152,10 @@ void MakeTrainerSpecFromFlags(TrainerSpec *trainer_spec,
SetTrainerSpecFromFlag(max_sentencepiece_length);
SetTrainerSpecFromFlag(split_by_unicode_script);
SetTrainerSpecFromFlag(split_by_whitespace);
+ SetTrainerSpecFromFlag(unk_id);
+ SetTrainerSpecFromFlag(bos_id);
+ SetTrainerSpecFromFlag(eos_id);
+ SetTrainerSpecFromFlag(pad_id);
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc
index a633144..c19743f 100644
--- a/src/trainer_interface.cc
+++ b/src/trainer_interface.cc
@@ -14,6 +14,7 @@
#include "trainer_interface.h"
+#include <cstdlib>
#include <memory>
#include <string>
#include <unordered_map>
@@ -37,9 +38,16 @@ const char TrainerInterface::kUNKStr[] = "\xe2\x96\x85";
const char32 TrainerInterface::kUPPBoundaryChar = L'\u0009';
const char TrainerInterface::kUPPBoundaryStr[] = "\t";
+const char TrainerInterface::kUNK[] = "<unk>";
+const char TrainerInterface::kBOS[] = "<s>";
+const char TrainerInterface::kEOS[] = "</s>";
+const char TrainerInterface::kPAD[] = "<pad>";
+
TrainerInterface::TrainerInterface(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
- : trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {}
+ : trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {
+ InitMetaPieces();
+}
TrainerInterface::~TrainerInterface() {}
bool TrainerInterface::IsValidSentencePiece(
@@ -105,11 +113,28 @@ void TrainerInterface::LoadSentences() {
const normalizer::Normalizer normalizer(normalizer_spec_);
+ CHECK(trainer_spec_.input_format().empty() ||
+ trainer_spec_.input_format() == "text" ||
+ trainer_spec_.input_format() == "tsv")
+ << "Supported formats are 'text' and 'tsv'.";
+
+ const bool is_tsv = trainer_spec_.input_format() == "tsv";
+
for (const auto &filename : trainer_spec_.input()) {
LOG(INFO) << "Loading corpus: " << filename;
std::string sentence;
io::InputBuffer input(filename);
while (input.ReadLine(&sentence)) {
+ int64 freq = 1;
+ if (is_tsv) {
+ const std::vector<std::string> v = string_util::Split(sentence, "\t");
+ CHECK_EQ(v.size(), 2)
+ << "Input format must be: word <tab> freq. " << sentence;
+ sentence = v[0];
+ freq = std::atoll(v[1].c_str());
+ CHECK_GE(freq, 1);
+ }
+
constexpr int kMaxLines = 2048;
if (sentence.size() > kMaxLines) {
continue;
@@ -133,9 +158,7 @@ void TrainerInterface::LoadSentences() {
continue;
}
- // TODO(taku): We assumes that the sentence frequency is always 1.
- // Support to use sentences with frequencies.
- sentences_.emplace_back(normalized, 1);
+ sentences_.emplace_back(normalized, freq);
if (sentences_.size() ==
static_cast<size_t>(trainer_spec_.input_sentence_size())) {
@@ -196,13 +219,14 @@ END:
w.first = string_util::UnicodeTextToUTF8(uw2);
}
- // +3 for <unk>, <s>, </s>
+ // +3 for meta pieces.
if (trainer_spec_.model_type() != TrainerSpec::WORD &&
trainer_spec_.model_type() != TrainerSpec::CHAR) {
- CHECK_LT(static_cast<int>(required_chars_.size() + 3),
+ CHECK_LT(static_cast<int>(required_chars_.size() + meta_pieces_.size()),
trainer_spec_.vocab_size())
<< "Vocabulary size is smaller than required_chars. "
- << trainer_spec_.vocab_size() << " vs " << required_chars_.size() + 3
+ << trainer_spec_.vocab_size() << " vs "
+ << required_chars_.size() + meta_pieces_.size()
<< ". "
<< "Increase vocab_size or decrease character_coverage with "
<< "--character_coverage option.";
@@ -234,30 +258,12 @@ void TrainerInterface::Serialize(ModelProto *model_proto) const {
CHECK(dup.insert(piece).second) << piece << " is already defined";
};
- auto *unk = model_proto->add_pieces();
- unk->set_piece("<unk>");
- unk->set_type(ModelProto::SentencePiece::UNKNOWN);
- CheckPiece(unk->piece());
-
- for (const auto &w : {"<s>", "</s>"}) {
+ for (const auto &w : meta_pieces_) {
auto *sp = model_proto->add_pieces();
- sp->set_piece(w);
- sp->set_type(ModelProto::SentencePiece::CONTROL);
- CheckPiece(sp->piece());
- }
-
- for (const auto &w : trainer_spec_.control_symbols()) {
- auto *sp = model_proto->add_pieces();
- sp->set_piece(w);
- sp->set_type(ModelProto::SentencePiece::CONTROL);
- CheckPiece(sp->piece());
- }
-
- for (const auto &w : trainer_spec_.user_defined_symbols()) {
- auto *sp = model_proto->add_pieces();
- sp->set_piece(w);
- sp->set_type(ModelProto::SentencePiece::USER_DEFINED);
+ sp->set_piece(w.first);
+ sp->set_type(w.second);
sp->set_score(0.0);
+ CHECK_NE(ModelProto::SentencePiece::NORMAL, sp->type());
CheckPiece(sp->piece());
}
@@ -304,6 +310,45 @@ void TrainerInterface::SaveVocab(StringPiece filename) const {
void TrainerInterface::Save() const {
SaveModel(trainer_spec_.model_prefix() + ".model");
SaveVocab(trainer_spec_.model_prefix() + ".vocab");
- // SaveSplits(trainer_spec_.model_prefix() + ".splits");
}
+
+void TrainerInterface::InitMetaPieces() {
+ CHECK(meta_pieces_.empty());
+
+ std::vector<std::pair<int, std::string>> ids;
+ if (trainer_spec_.unk_id() >= 0)
+ ids.emplace_back(trainer_spec_.unk_id(), kUNK);
+ if (trainer_spec_.bos_id() >= 0)
+ ids.emplace_back(trainer_spec_.bos_id(), kBOS);
+ if (trainer_spec_.eos_id() >= 0)
+ ids.emplace_back(trainer_spec_.eos_id(), kEOS);
+ if (trainer_spec_.pad_id() >= 0)
+ ids.emplace_back(trainer_spec_.pad_id(), kPAD);
+
+ std::sort(ids.begin(), ids.end());
+
+ int prev_id = -1;
+ bool has_unk = false;
+ for (const auto &p : ids) {
+ CHECK_EQ(prev_id + 1, p.first)
+ << "ID for `" << p.second << "` must be " << prev_id + 1;
+ prev_id = p.first;
+ CHECK_EQ(static_cast<int>(meta_pieces_.size()), p.first);
+ if (p.second == kUNK) has_unk = true;
+ meta_pieces_.emplace_back(
+ p.second, (p.second == kUNK ? ModelProto::SentencePiece::UNKNOWN
+ : ModelProto::SentencePiece::CONTROL));
+ }
+
+ CHECK(has_unk) << kUNK << " must be defined.";
+
+ for (const auto &w : trainer_spec_.control_symbols()) {
+ meta_pieces_.emplace_back(w, ModelProto::SentencePiece::CONTROL);
+ }
+
+ for (const auto &w : trainer_spec_.user_defined_symbols()) {
+ meta_pieces_.emplace_back(w, ModelProto::SentencePiece::USER_DEFINED);
+ }
+}
+
} // namespace sentencepiece
diff --git a/src/trainer_interface.h b/src/trainer_interface.h
index 67c3bf5..240c73b 100644
--- a/src/trainer_interface.h
+++ b/src/trainer_interface.h
@@ -55,6 +55,11 @@ class TrainerInterface {
static const char kUNKStr[];
static const char kUPPBoundaryStr[];
+ static const char kUNK[];
+ static const char kBOS[];
+ static const char kEOS[];
+ static const char kPAD[];
+
TrainerInterface(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec);
@@ -63,6 +68,7 @@ class TrainerInterface {
virtual void Train() {}
FRIEND_TEST(TrainerInterfaceTest, IsValidSentencePieceTest);
+ FRIEND_TEST(TrainerInterfaceTest, OverrideSpecialPieces);
protected:
// Returns true if |piece| is valid sentence piece.
@@ -100,6 +106,11 @@ class TrainerInterface {
// Normalizer spec
NormalizerSpec normalizer_spec_;
+ // Reserved control pieces. e.g., <unk>, <s>, </s>.
+ // The index corresponds to vocab id.
+ std::vector<std::pair<std::string,
+ ModelProto::SentencePiece::Type>> meta_pieces_;
+
private:
// Serialize final_pieces_ to |model_proto|.
void Serialize(ModelProto *model_proto) const;
@@ -112,6 +123,9 @@ class TrainerInterface {
// Saves vocabulary file for NMT.
void SaveVocab(StringPiece filename) const;
+
+ // Initializes `meta_pieces_` from TrainerSpec.
+ void InitMetaPieces();
};
} // namespace sentencepiece
#endif // TRAINER_INTERFACE_H_
diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc
index 6eb3daa..d07a48b 100644
--- a/src/trainer_interface_test.cc
+++ b/src/trainer_interface_test.cc
@@ -73,4 +73,103 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_TRUE(IsValid("1234"));
EXPECT_FALSE(IsValid("12345"));
}
+
+TEST(TrainerInterfaceTest, OverrideSpecialPieces) {
+ TrainerSpec trainer_spec;
+ NormalizerSpec normalizer_spec;
+
+ // Check default values.
+ EXPECT_EQ(0, trainer_spec.unk_id());
+ EXPECT_EQ(1, trainer_spec.bos_id());
+ EXPECT_EQ(2, trainer_spec.eos_id());
+ EXPECT_EQ(-1, trainer_spec.pad_id());
+
+ {
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(1);
+ trainer_spec.set_eos_id(2);
+ trainer_spec.set_pad_id(3);
+
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_EQ(4, trainer.meta_pieces_.size());
+ EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+ EXPECT_EQ("<s>", trainer.meta_pieces_[1].first);
+ EXPECT_EQ("</s>", trainer.meta_pieces_[2].first);
+ EXPECT_EQ("<pad>", trainer.meta_pieces_[3].first);
+ }
+
+ {
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(3);
+ trainer_spec.set_eos_id(2);
+ trainer_spec.set_pad_id(1);
+
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_EQ(4, trainer.meta_pieces_.size());
+ EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+ EXPECT_EQ("<pad>", trainer.meta_pieces_[1].first);
+ EXPECT_EQ("</s>", trainer.meta_pieces_[2].first);
+ EXPECT_EQ("<s>", trainer.meta_pieces_[3].first);
+ }
+
+ {
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(-1);
+ trainer_spec.set_eos_id(1);
+ trainer_spec.set_pad_id(-1);
+
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_EQ(2, trainer.meta_pieces_.size());
+ EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+ EXPECT_EQ("</s>", trainer.meta_pieces_[1].first);
+ }
+
+ {
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(-1);
+ trainer_spec.set_eos_id(-1);
+ trainer_spec.set_pad_id(-1);
+
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_EQ(1, trainer.meta_pieces_.size());
+ EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+ }
+
+ {
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(1);
+ trainer_spec.set_eos_id(2);
+ trainer_spec.set_pad_id(-1);
+
+ trainer_spec.add_control_symbols("<c1>");
+ trainer_spec.add_control_symbols("<c2>");
+ trainer_spec.add_user_defined_symbols("<u1>");
+ trainer_spec.add_user_defined_symbols("<u2>");
+
+ TrainerInterface trainer(trainer_spec, normalizer_spec);
+ EXPECT_EQ(7, trainer.meta_pieces_.size());
+ EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+ EXPECT_EQ("<s>", trainer.meta_pieces_[1].first);
+ EXPECT_EQ("</s>", trainer.meta_pieces_[2].first);
+ EXPECT_EQ("<c1>", trainer.meta_pieces_[3].first);
+ EXPECT_EQ("<c2>", trainer.meta_pieces_[4].first);
+ EXPECT_EQ("<u1>", trainer.meta_pieces_[5].first);
+ EXPECT_EQ("<u2>", trainer.meta_pieces_[6].first);
+ }
+
+ {
+ // ID is not contiguous.
+ trainer_spec.set_unk_id(0);
+ trainer_spec.set_bos_id(-1);
+ trainer_spec.set_eos_id(2);
+ EXPECT_DEATH(TrainerInterface trainer(trainer_spec, normalizer_spec));
+
+ // UNK is not defined.
+ trainer_spec.set_unk_id(-1);
+ trainer_spec.set_bos_id(0);
+ trainer_spec.set_eos_id(1);
+ EXPECT_DEATH(TrainerInterface trainer(trainer_spec, normalizer_spec));
+ }
+}
+
} // namespace sentencepiece
diff --git a/src/unigram_model.cc b/src/unigram_model.cc
index 3107c61..075a9dc 100644
--- a/src/unigram_model.cc
+++ b/src/unigram_model.cc
@@ -424,7 +424,7 @@ void ModelBase::PopulateNodes(Lattice *lattice) const {
if (!has_single_node) {
Lattice::Node *node = lattice->Insert(begin_pos, 1);
- node->id = kUnkID; // add UNK node.
+ node->id = unk_id_; // add UNK node.
node->score = unk_score;
}
}
@@ -437,7 +437,7 @@ int ModelBase::PieceToId(StringPiece piece) const {
}
int id = 0;
trie_->exactMatchSearch(piece.data(), id);
- return id == -1 ? kUnkID : id;
+ return id == -1 ? unk_id_ : id;
}
void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) {
@@ -478,8 +478,6 @@ Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
min_score_ = FLT_MAX;
- CheckControlSymbols();
-
std::vector<std::pair<std::string, int>> pieces; // <piece, vocab_id>
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
@@ -490,6 +488,7 @@ Model::Model(const ModelProto &model_proto) {
pieces.emplace_back(sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
+ if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
}
if (sp.type() == ModelProto::SentencePiece::NORMAL) {
min_score_ = std::min(min_score_, sp.score());
diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc
index d9dd199..b022f93 100644
--- a/src/unigram_model_trainer.cc
+++ b/src/unigram_model_trainer.cc
@@ -448,11 +448,7 @@ TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
}
}
- const int meta_symbols_size = trainer_spec_.control_symbols().size() +
- trainer_spec_.user_defined_symbols().size() +
- 3; // <s>, </s>, <unk>
-
- const int vocab_size_size = trainer_spec_.vocab_size() - meta_symbols_size;
+ const int vocab_size_size = trainer_spec_.vocab_size() - meta_pieces_.size();
CHECK_GT(vocab_size_size, 0);
// Then keeps sentencepieces with higher scores.
diff --git a/src/word_model.cc b/src/word_model.cc
index 66fdcd2..4972746 100644
--- a/src/word_model.cc
+++ b/src/word_model.cc
@@ -21,7 +21,6 @@ namespace word {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
- CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
@@ -32,6 +31,7 @@ Model::Model(const ModelProto &model_proto) {
port::InsertOrDie(&pieces_, sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
+ if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
}
}
}
diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc
index 0fee4cc..2892012 100644
--- a/src/word_model_trainer.cc
+++ b/src/word_model_trainer.cc
@@ -47,11 +47,8 @@ void Trainer::Train() {
freq[s.to_string()] += it.second;
}
}
- const int meta_symbols_size = trainer_spec_.control_symbols().size() +
- trainer_spec_.user_defined_symbols().size() +
- 3; // <s>, </s>, <unk>
- const int vocab_size = trainer_spec_.vocab_size() - meta_symbols_size;
+ const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size();
CHECK_GE(vocab_size, 0);
uint64 sum = 0;